In [1]:
import torch

from data import data_helper
from data.data_helper import available_datasets
from models import model_factory
from optimizer.optimizer_helper import get_optim_and_scheduler
from torch.nn import functional as F
from torch import nn

from utils.Logger import Logger

import torchvision
import matplotlib.pyplot as plt

from train_jigsaw import do_training


# --- Experiment configuration --------------------------------------------
# Loader settings.
batch_size, num_workers = 128, 4
# Both heads are configured with the same number of outputs.
jig_classes = class_classes = 31
# Optimisation settings.
lr, epochs = 0.001, 30
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Domains available in this experiment.
office_list = ["amazon", "dslr", "webcam"]
# NOTE(review): "dslr" is listed both as a source and as the target. do_epoch
# only applies the classification loss to samples with d_idx == 0, so this is
# presumably intentional (target images feeding the jigsaw loss) -- confirm.
source, target = ["amazon", "dslr"], "dslr"

dataloaders = {
    "train": data_helper.get_dataloader(source, jig_classes, "train"),
    "val": data_helper.get_dataloader(target, jig_classes, "val"),
}
# Number of samples behind each loader, keyed like `dataloaders`.
dataset_sizes = {split: len(loader.dataset) for split, loader in dataloaders.items()}
print(dataset_sizes)

Using multiple sources
{'train': 3315, 'val': 498}


In [3]:
def get_optim_and_scheduler(network, epochs, lr):
    """Build the SGD optimizer and the step LR scheduler used for training.

    Note: this definition shadows the ``get_optim_and_scheduler`` imported from
    ``optimizer.optimizer_helper`` in the notebook's import cell.

    Args:
        network: module whose ``parameters()`` are optimised.
        epochs: planned number of training epochs; the LR is stepped after
            90% of them (rounded down, but never less than one epoch).
        lr: initial learning rate.

    Returns:
        ``(optimizer, scheduler)`` -- Nesterov SGD (momentum 0.9, weight decay
        5e-4) and a ``StepLR`` attached to it.
    """
    from torch import optim
    optimizer = optim.SGD(network.parameters(), weight_decay=.0005, momentum=.9, nesterov=True, lr=lr)
    # int(epochs * .9) is 0 for epochs == 1; StepLR computes
    # `last_epoch % step_size`, so a zero step size raises ZeroDivisionError
    # on a later scheduler.step(). Clamp to at least one epoch.
    step_size = max(1, int(epochs * .9))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)
    print(step_size)
    return optimizer, scheduler

# Look up the "caffenet" constructor, build the two-headed network and move it
# to the selected device.
build_network = model_factory.get_network("caffenet")
model_ft = build_network(jigsaw_classes=jig_classes, classes=class_classes).to(device)

# Optimizer and LR scheduler for the freshly built network.
optimizer, scheduler = get_optim_and_scheduler(model_ft, epochs, lr)

27


In [4]:
def do_epoch(model, source, target, optimizer, logger, device, jig_weight=None):
    """Train for one pass over ``source`` and then evaluate on ``target``.

    Both loaders must yield ``((data, jigsaw_label, class_label), d_idx)``.
    The jigsaw loss uses every sample of a batch; the classification loss uses
    only the samples whose ``d_idx`` equals 0 (presumably the labelled source
    domain -- confirm against the data loader).

    Args:
        model: module returning ``(jigsaw_logit, class_logit)`` for a batch.
        source: iterable of training batches (must support ``len``).
        target: iterable of evaluation batches.
        optimizer: optimizer updating ``model``'s parameters.
        logger: object exposing ``log(...)`` and ``log_test(...)``.
        device: device the batch tensors are moved to.
        jig_weight: multiplier of the jigsaw loss in the total loss. When
            ``None`` (the default) the notebook-level global ``jig_weight`` is
            used, which keeps existing calls working.

    Raises:
        NameError: if ``jig_weight`` is neither passed nor defined globally.
    """
    if jig_weight is None:
        # Backward-compatible path: earlier revisions read the notebook global.
        try:
            jig_weight = globals()["jig_weight"]
        except KeyError:
            raise NameError(
                "name 'jig_weight' is not defined; pass jig_weight=... or set it at notebook level"
            ) from None

    criterion = nn.CrossEntropyLoss()
    model.train()
    for it, ((data, jig_l, class_l), d_idx) in enumerate(source):
        data, jig_l, class_l = data.to(device), jig_l.to(device), class_l.to(device)
        # Keep the selection mask on the same device as the tensors it indexes.
        class_mask = (d_idx == 0).to(device)

        optimizer.zero_grad()

        jigsaw_logit, class_logit = model(data)
        jigsaw_loss = criterion(jigsaw_logit, jig_l)
        if class_mask.any():
            class_loss = criterion(class_logit[class_mask], class_l[class_mask])
        else:
            # Mean-reduced cross entropy over an empty selection is NaN and
            # would poison every weight on backward(); contribute zero instead.
            class_loss = class_logit.new_zeros(())
        _, cls_pred = class_logit.max(dim=1)
        _, jig_pred = jigsaw_logit.max(dim=1)
        loss = class_loss + jigsaw_loss * jig_weight

        loss.backward()
        optimizer.step()

        # Losses are logged as tensors (other cells stack them); the hit
        # counts are computed over the whole batch.
        logger.log(it, len(source), {"jigsaw": jigsaw_loss, "class": class_loss},
                  {"jigsaw": torch.sum(jig_pred == jig_l.data).item(), "class":torch.sum(cls_pred == class_l.data).item()},
                  data.shape[0])

    model.eval()
    with torch.no_grad():
        jigsaw_correct = 0
        class_correct = 0
        total = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(target):
            data, jig_l, class_l = data.to(device), jig_l.to(device), class_l.to(device)
            jigsaw_logit, class_logit = model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # Accumulate plain Python ints rather than device tensors.
            class_correct += torch.sum(cls_pred == class_l.data).item()
            jigsaw_correct += torch.sum(jig_pred == jig_l.data).item()
            total += data.shape[0]
        logger.log_test({"jigsaw": float(jigsaw_correct) / total,
                         "class": float(class_correct) / total})


def do_training(epochs, model, source, target, optimizer, scheduler, device):
    """Run ``epochs`` rounds of ``do_epoch`` and return the logger and model.

    Note: this definition shadows the ``do_training`` imported from
    ``train_jigsaw`` in the notebook's import cell.

    Args:
        epochs: number of epochs to run.
        model: network to train (updated in place).
        source: training data loader.
        target: evaluation data loader.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler bound to ``optimizer``.
        device: device batches are moved to.

    Returns:
        ``(logger, model)`` -- the ``Logger`` filled during training and the
        trained model.
    """
    logger = Logger(epochs)
    for _ in range(epochs):
        # Report the learning rates actually set on the optimizer for this
        # epoch; scheduler.get_lr() is not meant to be called outside step().
        current_lrs = [group["lr"] for group in optimizer.param_groups]
        logger.new_epoch(current_lrs)
        do_epoch(model, source, target, optimizer, logger, device)
        # Step the schedule only after the epoch's optimizer updates; stepping
        # first skips the first value of the schedule on PyTorch >= 1.1.
        scheduler.step()
    return logger, model

In [5]:
# Weight of the jigsaw loss in the total loss; do_epoch reads this
# notebook-level name.
jig_weight = 0.5

# Train on the "train" loader and evaluate on the "val" loader after each epoch.
logger, model = do_training(
    epochs,
    model_ft,
    dataloaders["train"],
    dataloaders["val"],
    optimizer,
    scheduler,
    device,
)

New epoch - lr: 0.001
0/25 of epoch 1/30 jigsaw : 3.469644, class : 3.453006 - acc jigsaw : 0.039062, class : 0.039062 [bs:128]
10/25 of epoch 1/30 jigsaw : 3.457042, class : 3.156166 - acc jigsaw : 0.062500, class : 0.156250 [bs:128]
20/25 of epoch 1/30 jigsaw : 3.435227, class : 2.585897 - acc jigsaw : 0.054688, class : 0.296875 [bs:128]
Accuracies on target: jigsaw : 0.000000, class : 0.287149
New epoch - lr: 0.001
0/25 of epoch 2/30 jigsaw : 3.404847, class : 2.490460 - acc jigsaw : 0.054688, class : 0.257812 [bs:128]
10/25 of epoch 2/30 jigsaw : 3.279945, class : 2.115338 - acc jigsaw : 0.085938, class : 0.304688 [bs:128]
20/25 of epoch 2/30 jigsaw : 3.244253, class : 1.922949 - acc jigsaw : 0.125000, class : 0.398438 [bs:128]
Accuracies on target: jigsaw : 0.000000, class : 0.353414
New epoch - lr: 0.001
0/25 of epoch 3/30 jigsaw : 3.138083, class : 1.949750 - acc jigsaw : 0.109375, class : 0.382812 [bs:128]
10/25 of epoch 3/30 jigsaw : 2.938991, class : 2.217246 - acc jigsaw : 0

20/25 of epoch 21/30 jigsaw : 0.827653, class : 1.125209 - acc jigsaw : 0.796875, class : 0.625000 [bs:128]
Accuracies on target: jigsaw : 0.004016, class : 0.361446
New epoch - lr: 0.001
0/25 of epoch 22/30 jigsaw : 1.024132, class : 1.090514 - acc jigsaw : 0.742188, class : 0.578125 [bs:128]
10/25 of epoch 22/30 jigsaw : 1.078162, class : 1.533567 - acc jigsaw : 0.656250, class : 0.554688 [bs:128]
20/25 of epoch 22/30 jigsaw : 0.796080, class : 0.926412 - acc jigsaw : 0.773438, class : 0.679688 [bs:128]
Accuracies on target: jigsaw : 0.004016, class : 0.399598
New epoch - lr: 0.001
0/25 of epoch 23/30 jigsaw : 0.914971, class : 1.089072 - acc jigsaw : 0.703125, class : 0.648438 [bs:128]
10/25 of epoch 23/30 jigsaw : 0.839459, class : 0.833350 - acc jigsaw : 0.750000, class : 0.664062 [bs:128]
20/25 of epoch 23/30 jigsaw : 0.785613, class : 0.980675 - acc jigsaw : 0.765625, class : 0.617188 [bs:128]
Accuracies on target: jigsaw : 0.000000, class : 0.363454
New epoch - lr: 0.001
0/25 o

In [None]:
def to_plt(inp):
    """Convert a channel-first image tensor into an array for ``plt.imshow``.

    Implemented with torch ops only, so it does not depend on ``np`` having
    been imported by a later cell, and it also accepts tensors that require
    grad or live on another device.

    Args:
        inp: 3-D tensor laid out as (channels, height, width).

    Returns:
        NumPy array laid out as (height, width, channels) with every value
        clipped to the range [0, 1].
    """
    # detach() / cpu() make .numpy() legal for any tensor; permute moves the
    # channel axis last; clamp restricts values to [0, 1].
    return inp.detach().cpu().permute(1, 2, 0).clamp(0, 1).numpy()

# First convolution of torchvision's pretrained AlexNet, shown as a reference
# (swap in `model_ft.features[0]` to inspect the fine-tuned network instead).
# `models` was never imported in this notebook; go through the already-imported
# `torchvision` package to avoid a NameError.
conv1 = torchvision.models.alexnet(pretrained=True).features[0]
tmp = conv1.weight.detach().cpu()
# Tile the filters into one image, rescaled to [0, 1] for display.
tmp = torchvision.utils.make_grid(tmp, normalize=True)
plt.imshow(to_plt(tmp))
plt.show()

In [None]:
import numpy as np

# First-layer filters of the network trained above, tiled into one image and
# rescaled to [0, 1] for display.
conv1 = model_ft.features[0]
tmp = torchvision.utils.make_grid(conv1.weight.cpu().data, normalize=True)
plt.imshow(to_plt(tmp))
plt.show()

In [None]:
# NOTE(review): `memory` is not defined in any cell visible in this notebook;
# it presumably comes from an earlier or removed cell -- confirm before running.
plt.plot(memory["train"], label="train")
plt.plot(memory["val"], label="val")
# Without a legend the `label=` arguments above are never displayed.
plt.legend()
plt.show()

In [None]:
# Stack the "class" losses held by the logger into a single tensor, drop the
# autograd history, move it to host memory and convert it to a NumPy array.
torch.stack(logger.losses["class"]).detach().cpu().numpy()


# iter_c = iter(train_datasets)

# for x in range(5):
#     tmp = next(iter_c)
#     image = to_plt(tmp[0])
#     plt.imshow(image)
#     plt.show()

In [None]:
from os.path import join, dirname
# from data.JigsawLoader import JigsawTestDataset
import torch
import matplotlib.pyplot as plt
import numpy as np


def to_plt(inp):
    """Return a (C, H, W) tensor as an (H, W, C) NumPy array for ``plt.imshow``.

    Values are passed through untouched: the mean/std de-normalisation
    (mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]) and the
    clipping to [0, 1] that an earlier draft applied are switched off here.
    Note that this redefinition replaces any earlier ``to_plt`` in the notebook.
    """
    channels_last = inp.permute(1, 2, 0)
    return channels_last.numpy()

# dataset = JigsawTestDataset("", join('data/txt_lists', 'dslr_train.txt'), patches=False, classes=31)
# test = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)
# Pull one batch from the training loader. It yields
# ((images, jigsaw labels, class labels), d) -- the same layout do_epoch
# unpacks. The names are kept as-is because the next cell reads `d`.
iter_c = iter(dataloaders["train"])
(tmp, v, c), d = next(iter_c)
# Show up to five samples; a batch with fewer than five samples would
# otherwise raise IndexError.
for x in range(min(5, len(tmp))):
    image = torchvision.utils.make_grid(tmp[x], nrow=1, normalize=True)
    plt.imshow(to_plt(image))
    plt.show()
    print(v[x], c[x])

# Range of the jigsaw labels present in this batch.
print(v.max(), v.min())

In [None]:
# Shape of the slice of `d` (from the batch pulled above) equal to each of the
# values 0 and 1, i.e. how many samples of the batch carry each value.
[d[d == domain].shape for domain in (0, 1)]


In [None]:
from data.JigsawLoader import JigsawDataset
from PIL import Image
import torchvision.transforms as transforms

class JigsawTestDataset(JigsawDataset):
    """``JigsawDataset`` variant that returns the tiles in their original order.

    Each item is ``(data, 0, class_label)``: the image is cut into a
    ``grid_size`` x ``grid_size`` grid, each tile is resized to 75x75,
    converted to a tensor and normalised, and the stacked tiles are passed
    through ``self.returnFunc``. The second element is always 0.
    """

    def __init__(self, *args, **xargs):
        """Forward all arguments to ``JigsawDataset`` and replace the tile transform."""
        super().__init__(*args, **xargs)
        # Deterministic per-tile pipeline (no random crop): resize -> tensor -> normalise.
        tile_steps = [
            transforms.Resize((75, 75), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self._augment_tile = transforms.Compose(tile_steps)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(tiles, 0, class_label)``."""
        image_path = self.data_path + '/' + self.names[index]
        img = self._image_transformer(Image.open(image_path).convert('RGB'))

        side = self.grid_size
        # Edge length of one tile, derived from the image width.
        w = float(img.size[0]) / side
        tiles = []
        for n in range(side ** 2):
            # Row-major position of tile n inside the grid.
            row, col = divmod(n, side)
            box = [col * w, row * w, (col + 1) * w, (row + 1) * w]
            tiles.append(self._augment_tile(img.crop(box)))

        return self.returnFunc(torch.stack(tiles, 0)), 0, int(self.labels[index])


In [None]:
# Build the test-style dataset from the DSLR training list, with an empty
# data-path prefix.
dataset = JigsawTestDataset(
    "",
    join('data/txt_lists', 'dslr_train.txt'),
    patches=False,
    classes=31,
)