In [None]:
import torch

from data import data_helper
from data.data_helper import available_datasets
from models import model_factory
from optimizer.optimizer_helper import get_optim_and_scheduler
from torch.nn import functional as F
from torch import nn

from utils.Logger import Logger

import torchvision
import matplotlib.pyplot as plt

from train_jigsaw import Trainer
from utils import vis

class Container:
    """Empty attribute bag; the notebook stores its experiment settings (`args`) on an instance."""

# Experiment settings, stored as attributes on a Container and handed to `Trainer(args, device)`.
args = Container()
args.batch_size = 128
args.n_classes = 7
args.learning_rate = 0.001
args.epochs = 30
args.network = "caffenet"
args.val_size = 0.1
args.tf_logger = True
args.folder_name = "test"
args.jigsaw_n_classes = 31
args.classify_only_sane = True
args.jig_weight = 0.9
args.bias_whole_image = 0.8

# Use the second GPU when more than one is visible, otherwise the first GPU, or the CPU.
# Hard-coding "cuda:1" fails as soon as the machine has a single GPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Leave-one-domain-out sweep over the four (sorted) domains: each domain is the target once while
# the other three are the sources. Every pairing is trained three times, first with
# args.TTA = True and then with args.TTA = False.
source = sorted(["photo", "cartoon", "sketch", "art_painting"])
for use_tta in [True, False]:
    args.TTA = use_tta
    for held_out_idx, held_out in enumerate(source):
        args.source = source[:held_out_idx] + source[held_out_idx + 1:]
        args.target = held_out
        for run in range(3):
            header_args = ("-".join(args.source), args.target, args.jigsaw_n_classes, run)
            print("\n%s to %s - %d jigsaw classes, split %d" % header_args)
            trainer = Trainer(args, device)
            trainer.do_training()


cartoon-photo-sketch to art_painting - 31 jigsaw classes, split 0
Using Caffe AlexNet
Dataset size: train 7150, val 793, test 2048
Step size: 24
Saving to /home/enoon/code/2018/JigsawDA/myJigsaw/utils/../logs/test/cartoon-photo-sketch_to_art_painting/eps30_bs128_lr0.001_class7_jigClass31_jigWeight0.9_bias0.8_classifyOnlySane_TTA_855
New epoch - lr: 0.0, 0.001
0/55 of epoch 1/30 jigsaw : 3.456380, class : 1.990755 - acc jigsaw : 0.046875, class : 0.140625 [bs:128]
30/55 of epoch 1/30 jigsaw : 0.779950, class : 0.765542 - acc jigsaw : 0.812500, class : 0.703125 [bs:128]
Accuracies on val: jigsaw : 1.000000, class : 0.844893
Single vs multi: 0.605469 0.592773
Accuracies on test: jigsaw : 1.000000, class : 0.592773
New epoch - lr: 0.0, 0.001
0/55 of epoch 2/30 jigsaw : 0.770560, class : 0.392745 - acc jigsaw : 0.804688, class : 0.765625 [bs:128]
30/55 of epoch 2/30 jigsaw : 0.971272, class : 0.550318 - acc jigsaw : 0.773438, class : 0.726562 [bs:128]
Accuracies on val: jigsaw : 1.000000, 

In [None]:
# Leave-one-domain-out sweep, additionally looping over jigsaw-loss weights and
# `bias_whole_image` values (currently one value each; earlier candidates kept in the comment).
source = sorted(["photo", "cartoon", "sketch", "art_painting"])
for weight in [0.0]:
    args.jig_weight = weight
    print("\n======================\n%g\n===================" % args.jig_weight)
    for bias in [1.0]: # [None, 0.01, 0.05, 0.1, 0.3, 0.5]:
        args.bias_whole_image = bias
        for target_idx, target_domain in enumerate(source):
            args.source = source[:target_idx] + source[target_idx + 1:]
            args.target = target_domain
            for run in range(3):
                header_args = ("-".join(args.source), args.target, args.jigsaw_n_classes, run)
                print("\n%s to %s - %d jigsaw classes, split %d" % header_args)
                trainer = Trainer(args, device)
                trainer.do_training()

In [None]:
# One training run with the current `args`; keep the returned logger and model for the cells below.
trainer = Trainer(args, device)
(logger, model) = trainer.do_training()

In [None]:
%matplotlib notebook
# Mean of the final two validation class accuracies (in percent), then the training curves.
class_val_acc = logger.val_acc["class"]
print(100 * (class_val_acc[-1] + class_val_acc[-2]) / 2.)
title = "%s->%s eps:%d jigweight:%.1f" % ("-".join(args.source), args.target, args.epochs, args.jig_weight)
vis.view_training(logger, title)

In [None]:
# The source domains feed the "train" loader and the held-out target domain feeds the "val" loader.
dataloaders = {
    phase: data_helper.get_dataloader(domains, args.jigsaw_n_classes, phase)
    for phase, domains in (("train", args.source), ("val", args.target))
}
dataset_sizes = {phase: len(phase_loader.dataset) for phase, phase_loader in dataloaders.items()}
print(dataset_sizes)

In [None]:
def get_optim_and_scheduler(network, epochs, lr):
    """Build a Nesterov-momentum SGD optimizer and a StepLR scheduler for `network`.

    NOTE: this local definition shadows the `get_optim_and_scheduler` imported from
    `optimizer.optimizer_helper` at the top of the notebook.

    Args:
        network: model exposing ``get_params(lr)``, whose result is passed to SGD as its
            parameters (presumably parameter groups with per-group LRs -- confirm in the model).
        epochs: total number of training epochs; the LR is decayed every
            ``int(0.8 * epochs)`` epochs (at least 1) by StepLR's default gamma of 0.1.
        lr: base learning rate.

    Returns:
        (optimizer, scheduler)
    """
    from torch import optim
    optimizer = optim.SGD(network.get_params(lr), weight_decay=.0005, momentum=.9, nesterov=True, lr=lr)
    # StepLR computes `last_epoch % step_size`, so a step size of 0 (epochs < 2) would crash.
    step_size = max(1, int(epochs * .8))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)
    print("Step size: %d" % step_size)
    return optimizer, scheduler

# Build the network selected by `args.network` (the settings cell sets "caffenet") instead of a
# hard-coded name. The jigsaw head gets jigsaw_n_classes + 1 outputs -- presumably one extra class
# for the unshuffled tile order (TODO confirm against the model factory).
model_ft = model_factory.get_network(args.network)(jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
model_ft = model_ft.to(device)
# print(model_ft)

optimizer, scheduler = get_optim_and_scheduler(model_ft, args.epochs, args.learning_rate)

In [None]:
def do_epoch(model, source, target, optimizer, logger, device, jigsaw_weight=None, target_id=None):
    """Run one training pass over `source`, then evaluate the model on `target`.

    Each batch is unpacked as ``((data, jig_l, class_l), d_idx)`` and ``model(data)`` must return
    ``(jigsaw_logit, class_logit)``. The optimised loss is
    ``class_loss + jigsaw_loss * jigsaw_weight``.

    Args:
        model: network returning (jigsaw logits, class logits).
        source: sized iterable of training batches (``len(source)`` is passed to the logger).
        target: iterable of evaluation batches.
        optimizer: optimizer over ``model``'s parameters.
        logger: object exposing ``log(it, n_iters, losses, n_correct, batch_size)`` and
            ``log_test(accuracies)``.
        device: torch device the batches are moved to.
        jigsaw_weight: weight of the jigsaw loss. ``None`` falls back to the module-level
            ``jig_weight`` variable, which is what the notebook's training cells rely on.
        target_id: domain index whose samples are excluded from the classification loss.
            ``None`` (default) keeps every sample. The previous version referenced an undefined
            global ``target_id`` here and always raised NameError.
    """
    if jigsaw_weight is None:
        jigsaw_weight = jig_weight  # module-level setting
    _train_one_epoch(model, source, optimizer, logger, device, jigsaw_weight, target_id)
    _evaluate(model, target, logger, device)


def _train_one_epoch(model, source, optimizer, logger, device, jigsaw_weight, target_id):
    """Optimise `model` on every batch of `source`, logging losses and correct counts per batch."""
    criterion = nn.CrossEntropyLoss()
    model.train()
    for it, ((data, jig_l, class_l), d_idx) in enumerate(source):
        data, jig_l, class_l = data.to(device), jig_l.to(device), class_l.to(device)

        optimizer.zero_grad()

        jigsaw_logit, class_logit = model(data)
        jigsaw_loss = criterion(jigsaw_logit, jig_l)
        if target_id is None:
            class_loss = criterion(class_logit, class_l)
        else:
            labeled = (d_idx != target_id).to(device)
            if labeled.any():
                class_loss = criterion(class_logit[labeled], class_l[labeled])
            else:
                # CrossEntropyLoss returns NaN on an empty selection; use a zero loss that keeps
                # the graph connected so the jigsaw term can still be optimised.
                class_loss = class_logit.sum() * 0.0
        _, cls_pred = class_logit.max(dim=1)
        _, jig_pred = jigsaw_logit.max(dim=1)
        loss = class_loss + jigsaw_loss * jigsaw_weight

        loss.backward()
        optimizer.step()

        logger.log(it, len(source),
                   {"jigsaw": jigsaw_loss.item(), "class": class_loss.item()},
                   {"jigsaw": torch.sum(jig_pred == jig_l).item(),
                    "class": torch.sum(cls_pred == class_l).item()},
                   data.shape[0])
        del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit


def _evaluate(model, target, logger, device):
    """Compute jigsaw and class accuracy of `model` over `target` and report them via `log_test`."""
    model.eval()
    jigsaw_correct = 0
    class_correct = 0
    total = 0
    with torch.no_grad():
        for (data, jig_l, class_l), _ in target:
            data, jig_l, class_l = data.to(device), jig_l.to(device), class_l.to(device)
            jigsaw_logit, class_logit = model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l).item()
            jigsaw_correct += torch.sum(jig_pred == jig_l).item()
            total += data.shape[0]
    total = max(total, 1)  # an empty evaluation set reports 0 accuracy instead of dividing by zero
    logger.log_test({"jigsaw": float(jigsaw_correct) / total,
                     "class": float(class_correct) / total})


def do_training(args, model, source, target, optimizer, scheduler, device):
    """Train `model` for `args.epochs` epochs, evaluating on `target` after every epoch.

    Returns:
        (logger, model): the Logger built from `args` that collected the per-epoch stats, and the
        trained model.
    """
    logger = Logger(args)
    for k in range(args.epochs):
        # Log the learning rates actually in effect for this epoch (one per param group).
        # `scheduler.get_lr()` is deprecated and may report a wrong value outside `step()`.
        logger.new_epoch([group["lr"] for group in optimizer.param_groups])
        do_epoch(model, source, target, optimizer, logger, device)
        # Since PyTorch 1.1 the scheduler must be stepped after the epoch's optimizer updates;
        # stepping first decays the learning rate one epoch early.
        scheduler.step()
    return logger, model

In [None]:
# `do_epoch` reads the module-level `jig_weight` for the jigsaw-loss weight, so set it from args first.
jig_weight = args.jig_weight
train_loader, val_loader = dataloaders["train"], dataloaders["val"]
logger, model = do_training(args, model_ft, train_loader, val_loader, optimizer, scheduler, device)

In [None]:
%matplotlib notebook
# Percentage mean of the final two validation class accuracies, then the training curves.
last_acc, prev_acc = logger.val_acc["class"][-1], logger.val_acc["class"][-2]
print(100 * (last_acc + prev_acc) / 2.)
plot_title = "%s->%s eps:%d jigweight:%.1f" % (str(args.source), args.target, args.epochs, jig_weight)
vis.view_training(logger, plot_title)

In [None]:
def to_plt(inp):
    """Convert a 3-D tensor to a numpy array for `plt.imshow`.

    Moves the first (channel) axis last and clips the values to [0, 1].
    """
    import numpy as np
    return np.clip(inp.numpy().transpose(1, 2, 0), 0, 1)

# Show the first layer's convolution weights as a normalised image grid.
# Alternative source (needs torchvision `models`, not imported here): models.alexnet(pretrained=True).features[0]
conv1 = model_ft.features[0]
filter_grid = torchvision.utils.make_grid(conv1.weight.data.cpu(), normalize=True)
plt.imshow(to_plt(filter_grid))
plt.show()

In [None]:
# Replot the training curves, titled from `args`. The bare `target`/`epochs` names used previously
# are never assigned at notebook level and raised NameError.
vis.view_training(logger, "%s->%s eps:%d jigweight:%.1f" % ("-".join(args.source), args.target, args.epochs, args.jig_weight))

In [None]:
# Training losses (left axis) against validation accuracies (right axis).
# `logger.val_acc` is a dict of per-metric lists (it is indexed with "class" in the cells above),
# so each metric is plotted as its own series. The old code passed the dict itself to `plot`
# and used len(dict), i.e. the number of metrics.
fig, ax1 = plt.subplots()
l = 0  # length of the last loss series; stays 0 if no losses were logged
for k, v in logger.losses.items():
    ax1.plot(v, label=k)
    l = len(v)
n_evals = len(logger.val_acc["class"])
updates = l / n_evals if n_evals else 0.0  # logged training points per evaluation
print(updates)
ax1.set_xlabel("logged training step")
ax1.set_ylabel("loss")
ax1.legend(loc="upper left")
ax2 = ax1.twinx()
for name, accs in logger.val_acc.items():
    # One x-position per evaluation, spaced `updates` apart, so lengths always match.
    ax2.plot([e * updates for e in range(len(accs))], accs, label="val acc (%s)" % name)
ax2.set_ylabel("accuracy")
ax2.legend(loc="upper right")
plt.show()

In [None]:
# Pair each validation class accuracy with the training-step position used on the plot's x-axis.
# Relies on `l` (length of the logged loss series) and `updates` from the previous cell.
# `logger.val_acc` is a dict of lists, so it is indexed by metric name, not by position.
# `zip` stops at the shorter sequence, and the step is kept >= 1 so `range` never gets a 0 step.
for step, acc in zip(range(0, l, max(1, int(updates))), logger.val_acc["class"]):
    print(step, acc)

In [None]:



# iter_c = iter(train_datasets)

# for x in range(5):
#     tmp = next(iter_c)
#     image = to_plt(tmp[0])
#     plt.imshow(image)
#     plt.show()

In [None]:
from data.data_helper import get_val_dataloader
from os.path import join, dirname
# from data.JigsawLoader import JigsawTestDataset
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision

In [None]:
# Validation loader over the "photo" domain with 31 jigsaw classes, batches of 10, multi=True;
# inspected by the cells below.
loader = get_val_dataloader(
    "photo",
    31,
    batch_size=10,
    multi=True,
)

In [None]:
def to_plt(inp):
    """Move the leading axis of a 3-D tensor to the end (CHW -> HWC) for `plt.imshow`.

    NOTE: this redefines the earlier `to_plt` WITHOUT clipping to [0, 1]; cells re-run after this
    one use the unclipped version.
    """
    return inp.numpy().transpose(1, 2, 0)

# dataset = JigsawTestDataset("", join('data/txt_lists', 'dslr_train.txt'), patches=False, classes=31)
# test = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)

# Take one batch from `loader` and display every entry along dim 1 of the first sample, printing
# the matching entry of the second batch element. That element is assumed to hold jigsaw labels,
# as in `do_epoch` -- TODO confirm for the multi=True loader.
batch_iter = iter(loader)
(images, jig_labels, class_labels), domain_idx = next(batch_iter)
print(images[:, 0].shape)
for view in range(images.shape[1]):
    # image = images[0, view]
    grid = torchvision.utils.make_grid(images[0, view], 1, normalize=True)
    plt.imshow(to_plt(grid))
    plt.show()
    print(jig_labels[0, view])

# print(jig_labels.max(), jig_labels.min())

In [None]:
# Stack ten random (5, 100) tensors into a (10, 5, 100) tensor and inspect the shape of its mean over dim 1.
res = torch.stack([torch.rand(5, 100) for _ in range(10)], 0)
res.mean(1).shape

In [None]:
from data.JigsawLoader import JigsawDataset
from PIL import Image
import torchvision.transforms as transforms

class JigsawTestDataset(JigsawDataset):
    """`JigsawDataset` variant that returns an image's grid tiles in row-major order, unpermuted.

    Relies on attributes set by the parent `JigsawDataset` (defined elsewhere): `data_path`,
    `names`, `labels`, `grid_size`, `_image_transformer` and `returnFunc`.
    """

    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)
        # Per-tile transform: resize to 75x75, convert to a tensor, normalise with ImageNet mean/std.
        self._augment_tile = transforms.Compose([
#             transforms.RandomCrop(64),
            transforms.Resize((75, 75), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        """Return ``(self.returnFunc(tiles), 0, label)`` for sample `index`.

        `tiles` stacks the ``grid_size ** 2`` transformed tiles along a new first dimension. The
        second element is always 0 (the tiles are not permuted here).
        """
        framename = self.data_path + '/' + self.names[index]
        # Context manager releases the file handle; `convert` returns an independent copy.
        with Image.open(framename) as raw_img:
            img = raw_img.convert('RGB')
        img = self._image_transformer(img)

        # Width and height are split separately so non-square images are tiled correctly
        # (identical to using the width for both when the transformed image is square).
        tile_w = float(img.size[0]) / self.grid_size
        tile_h = float(img.size[1]) / self.grid_size
        tiles = []
        for n in range(self.grid_size ** 2):
            row, col = divmod(n, self.grid_size)
            tile = img.crop([col * tile_w, row * tile_h, (col + 1) * tile_w, (row + 1) * tile_h])
            tiles.append(self._augment_tile(tile))

        data = torch.stack(tiles, 0)
        return self.returnFunc(data), 0, int(self.labels[index])


In [None]:
# Test-time jigsaw dataset over the DSLR training list, with an empty data_path prefix.
dataset = JigsawTestDataset(
    "",
    join('data/txt_lists', 'dslr_train.txt'),
    patches=False,
    classes=31,
)

In [None]:
import torch
import random

In [None]:
# Count how many of 100 uniform draws fall below 0.1 (expected around 10).
true = sum(0.1 > random.random() for _ in range(100))
true