In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from random import randint
import time
import utils
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import cv2
import random

In [2]:
!pip install tqdm



In [27]:
import torch.utils.data.dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from losses import compute_contrastive_loss_from_feats
from utils import *  # bad practice, nvm
from models import *

# from dataset import ImageDataset
from training_config import doodles, reals, doodle_size, real_size, NUM_CLASSES

In [4]:
ckpt_dir = 'exp_data'

In [28]:
def combined_dataset(datasets, size):
    combined_dataset = {}
    for name, dataset in datasets.items():
        for class_name, class_data in dataset.items():
            if class_name not in combined_dataset:
                combined_dataset[class_name] = []
            # resize data so they can be stacked
            resized = []
            for data in class_data:
                resized.append(cv2.resize(data, (size, size), interpolation=cv2.INTER_AREA))
            resized = np.stack(resized, axis=0)
            combined_dataset[class_name].append(resized)
    for class_name, lst_datasets in combined_dataset.items():
        combined_dataset[class_name] = np.concatenate(lst_datasets, axis=0)
    return combined_dataset


class ImageDataset(Dataset):
    DATASET_DIR = {True: 'dataset/dataset_train.npy', False: 'dataset/dataset_test.npy'}

    def __init__(self, doodles_list, real_list, doodle_size, real_size, train: bool):
        super(ImageDataset, self).__init__()

        dataset = np.load(self.DATASET_DIR[train], allow_pickle=True)[()]

        doodle_datasets = {name: data for name, data in dataset.items() if name in doodles_list}
        real_datasets = {name: data for name, data in dataset.items() if name in real_list}
        self.doodle_dict = combined_dataset(doodle_datasets, doodle_size)
        self.real_dict = combined_dataset(real_datasets, real_size)

        # sanity check
        assert set(self.doodle_dict.keys()) == set(self.real_dict.keys()), \
            f'doodle and real images label classes do not match'

        # process classes
        label_idx = {}
        for key in self.doodle_dict.keys():
            if key not in label_idx:
                label_idx[key] = len(label_idx)
        self.label_idx = label_idx

        # parse data and labels
        self.doodle_data, self.doodle_label = self._return_x_y_pairs(self.doodle_dict, label_idx)
        self.real_data, self.real_label = self._return_x_y_pairs(self.real_dict, label_idx)

        # data preprocessing
        self.doodle_preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(doodle_size),
            transforms.ToTensor(),
            transforms.Normalize((self.doodle_data/255).mean(), (self.doodle_data/255).std())   # IMPORTANT / 255
        ])

        self.real_preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(real_size),
            transforms.ToTensor(),
            transforms.Normalize((self.real_data/255).mean(axis=(0, 1, 2)), (self.real_data/255).std(axis=(0, 1, 2)))
        ])

        print(f'Train = {train}. Doodle list: {doodles_list}, \n real list: {real_list}. \n classes: {label_idx.keys()} \n'
              f'Doodle data size {len(self.doodle_data)}, real data size {len(self.real_data)}, '
              f'ratio {len(self.doodle_data)/len(self.real_data)}')

    def _return_x_y_pairs(self, data_dict, category_mapping):
        xs, ys = [], []
        for key in data_dict.keys():
            data = data_dict[key]
            labels = [category_mapping[key]] * len(data)
            xs.append(data)
            ys.extend(labels)
        return np.concatenate(xs, axis=0), np.array(ys)

    def __getitem__(self, idx):
        # naive sampling scheme - sample with replacement
        # sample label first so that doodle and real data belong to the same category
        label = random.choice(list(self.label_idx.keys()))
        doodle_data = self.doodle_preprocess(random.choice(self.doodle_dict[label]))
        real_data = self.real_preprocess(random.choice(self.real_dict[label]))
        numer_label = self.label_idx[label]
        return doodle_data, numer_label, real_data, numer_label

    def __len__(self):
        return max(len(self.doodle_data), len(self.real_data)) # could be arbitrary number

In [44]:
# def convbn(in_channels, out_channels, kernel_size, stride, padding, bias):
#     return nn.Sequential(
#         nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
#         nn.BatchNorm2d(out_channels),
#         nn.ReLU(inplace=True)
#     )

# class V2ConvNet(nn.Module):
#     CHANNELS = [64, 128, 192, 256, 512]
#     POOL = (1, 1)

#     def __init__(self, in_c, num_classes, dropout=0.2, add_layers=False):
#         super().__init__()
#         layer1 = convbn(in_c, self.CHANNELS[1], kernel_size=3, stride=2, padding=1, bias=True)
#         layer2 = convbn(self.CHANNELS[1], self.CHANNELS[2], kernel_size=3, stride=2, padding=1, bias=True)
#         layer3 = convbn(self.CHANNELS[2], self.CHANNELS[3], kernel_size=3, stride=2, padding=1, bias=True)
#         layer4 = convbn(self.CHANNELS[3], self.CHANNELS[4], kernel_size=3, stride=2, padding=1, bias=True)
#         pool = nn.AdaptiveAvgPool2d(self.POOL)
#         self.layers = nn.Sequential(layer1, layer2, layer3, layer4, pool)

#         if add_layers:
#             layer1_2 = convbn(self.CHANNELS[1], self.CHANNELS[1], kernel_size=3, stride=1, padding=0, bias=True)
#             layer2_2 = convbn(self.CHANNELS[2], self.CHANNELS[2], kernel_size=3, stride=1, padding=0, bias=True)
#             layer3_2 = convbn(self.CHANNELS[3], self.CHANNELS[3], kernel_size=3, stride=1, padding=0, bias=True)
#             layer4_2 = convbn(self.CHANNELS[4], self.CHANNELS[4], kernel_size=3, stride=1, padding=0, bias=True)
#             self.layers = nn.Sequential(layer1, layer1_2, layer2, layer2_2, layer3, layer3_2, layer4, layer4_2, pool)

#         self.nn = nn.Linear(self.POOL[0] * self.POOL[1] * self.CHANNELS[4], num_classes)
#         self.dropout = nn.Dropout(p=dropout)

#     def forward(self, x, return_feats=False):
#         feats = self.layers(x).flatten(1)
#         x = self.nn(self.dropout(feats))

#         if return_feats:
#             return x, feats

#         return x

class V2ConvNet(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, num_classes=9):
        super(V2ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride
        
        self.pool = nn.AvgPool2d(2)
        
        self.lin1 = nn.Linear(256, 128)
        self.lin2 = nn.Linear(128, num_classes)

    def forward(self, x, return_feats=False):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        feats = self.relu(out)
        
        x = out.view(out.size(0), 256, -1).mean(2) # global average pooling
        
        x = self.lin1(x)
        pred = self.lin2(x)
        
        if return_feats:
            return pred, feats

        return pred

In [45]:
x = torch.rand(100, 3, 64, 64)
net = V2ConvNet(3, 64, num_classes=9)
y = net(x)
print (y.shape)

torch.Size([100, 9])


In [47]:
def train_model(model1, model2, train_set, val_set, tqdm_on, id, num_epochs, batch_size, learning_rate, c1, c2, t):
    # cuda side setup
    model1 = nn.DataParallel(model1).cuda()
    model2 = nn.DataParallel(model2).cuda()

    # training side
    optimizer = torch.optim.AdamW(params=list(model1.parameters()) + list(model2.parameters()),
                                  lr=learning_rate, weight_decay=3e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # load the training data
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=16,
                            pin_memory=True, drop_last=True)

    # training loop
    for epoch in range(num_epochs):
        loss1_model1 = AverageMeter()
        loss1_model2 = AverageMeter()
        loss2_model1 = AverageMeter()
        loss2_model2 = AverageMeter()
        loss3_combined = AverageMeter()
        acc_model1 = AverageMeter()
        acc_model2 = AverageMeter()

        model1.train()
        model2.train()
        pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=not tqdm_on)
        for i, (x1, y1, x2, y2) in enumerate(pg):
            # doodle, label, real, label
            x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()

            # train model1 (doodle)
            pred1, feats1 = model1(x1, return_feats=True)
            loss_1 = criterion(pred1, y1)    # classification loss
            loss1_model1.update(loss_1)
            loss_model1 = loss_1

            # train model2 (real)
            pred2, feats2 = model2(x2, return_feats=True)
            loss_1 = criterion(pred2, y2)   # classification loss
            loss1_model2.update(loss_1)
            loss_model2 = loss_1

            loss = loss_model1 + loss_model2

            # statistics
            acc_model1.update(compute_accuracy(pred1, y1))
            acc_model2.update(compute_accuracy(pred2, y2))

            # optimization
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # display
            pg.set_postfix({
                'acc 1': '{:.6f}'.format(acc_model1.avg),
                'acc 2': '{:.6f}'.format(acc_model2.avg),
                'l1m1': '{:.6f}'.format(loss1_model1.avg),
                'l1m2': '{:.6f}'.format(loss1_model2.avg),
                'train epoch': '{:03d}'.format(epoch)
            })

        print(
            f'train epoch {epoch}, acc 1={acc_model1.avg:.3f}, acc 2={acc_model2.avg:.3f}, l1m1={loss1_model1.avg:.3f}, 'f'l1m2={loss1_model2.avg:.3f}')

        # validation
        model1.eval(), model1.eval()
        acc_model1.reset(), acc_model2.reset()
        pg = tqdm(val_loader, leave=False, total=len(val_loader), disable=not tqdm_on)
        with torch.no_grad():
            for i, (x1, y1, x2, y2) in enumerate(pg):
                pred1, feats1 = model1(x1, return_feats=True)
                pred2, feats2 = model2(x2, return_feats=True)
                acc_model1.update(compute_accuracy(pred1, y1))
                acc_model2.update(compute_accuracy(pred2, y2))

                # display
                pg.set_postfix({
                    'acc 1': '{:.6f}'.format(acc_model1.avg),
                    'acc 2': '{:.6f}'.format(acc_model2.avg),
                    'val epoch': '{:03d}'.format(epoch)
                })

        print(f'validation epoch {epoch}, acc 1 (doodle) = {acc_model1.avg:.3f}, acc 2 (real) = {acc_model2.avg:.3f}')

        scheduler.step()

    print(f'training finished')

    # save checkpoint
    exp_dir = f'exp_data/{id}'
    save_model(exp_dir, f'{id}_model1.pt', model1)
    save_model(exp_dir, f'{id}_model2.pt', model2)

In [48]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [31]:
fix_seed(0)

train_set = ImageDataset(doodles, reals, doodle_size, real_size, train=True)
val_set = ImageDataset(doodles, reals, doodle_size, real_size, train=False)

Train = True. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 7022, real data size 46364, ratio 0.15145371408851696
Train = False. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 1764, real data size 9341, ratio 0.18884487742211756


In [None]:
# tunable hyper params.
use_cnn = True
num_epochs, base_bs, base_lr = 15, 512, 2e-2
c1, c2, t = 0, 0, 0.1  # contrastive learning. if you want vanilla (cross-entropy) training, set c1 and c2 to 0.
dropout = 0.3

# models
doodle_model = V2ConvNet(1, 64, num_classes=9)
real_model = V2ConvNet(3, 64, num_classes=9)

# just some logistics
tqdm_on = True     # progress bar
id = 25             # change to the id of each experiment accordingly

train_model(doodle_model, real_model, train_set, val_set, tqdm_on, id, num_epochs, base_bs, base_lr, c1, c2, t)

                                                                                               

train epoch 0, acc 1=0.177, acc 2=0.261, l1m1=2.256,l1m2=2.064


                                                                                               

validation epoch 0, acc 1 (doodle) = 0.220, acc 2 (real) = 0.272


                                                                                               

train epoch 1, acc 1=0.280, acc 2=0.332, l1m1=1.872,l1m2=1.776


                                                                                               

validation epoch 1, acc 1 (doodle) = 0.302, acc 2 (real) = 0.333


                                                                                               

train epoch 2, acc 1=0.326, acc 2=0.387, l1m1=1.765,l1m2=1.607


                                                                                               

validation epoch 2, acc 1 (doodle) = 0.358, acc 2 (real) = 0.372


                                                                                               

train epoch 3, acc 1=0.343, acc 2=0.425, l1m1=1.719,l1m2=1.522


                                                                                               

validation epoch 3, acc 1 (doodle) = 0.325, acc 2 (real) = 0.388


                                                                                               

train epoch 4, acc 1=0.355, acc 2=0.464, l1m1=1.683,l1m2=1.437


                                                                                               

validation epoch 4, acc 1 (doodle) = 0.356, acc 2 (real) = 0.430


                                                                                               

train epoch 5, acc 1=0.370, acc 2=0.497, l1m1=1.656,l1m2=1.369


                                                                                               

validation epoch 5, acc 1 (doodle) = 0.327, acc 2 (real) = 0.449


                                                                                               

train epoch 6, acc 1=0.376, acc 2=0.513, l1m1=1.633,l1m2=1.328


                                                                                               

validation epoch 6, acc 1 (doodle) = 0.369, acc 2 (real) = 0.476


                                                                                               

train epoch 7, acc 1=0.387, acc 2=0.538, l1m1=1.601,l1m2=1.273


                                                                                               

validation epoch 7, acc 1 (doodle) = 0.351, acc 2 (real) = 0.479


                                                                                               

train epoch 8, acc 1=0.396, acc 2=0.550, l1m1=1.588,l1m2=1.241


                                                                                               

validation epoch 8, acc 1 (doodle) = 0.386, acc 2 (real) = 0.464


                                                                                               

train epoch 9, acc 1=0.401, acc 2=0.571, l1m1=1.576,l1m2=1.195


                                                                                               

validation epoch 9, acc 1 (doodle) = 0.369, acc 2 (real) = 0.487


                                                                                               

train epoch 10, acc 1=0.415, acc 2=0.575, l1m1=1.542,l1m2=1.177


                                                                                               

validation epoch 10, acc 1 (doodle) = 0.382, acc 2 (real) = 0.503


                                                                                               

train epoch 11, acc 1=0.420, acc 2=0.581, l1m1=1.533,l1m2=1.158


                                                                                               

validation epoch 11, acc 1 (doodle) = 0.388, acc 2 (real) = 0.504


                                                                                               

train epoch 12, acc 1=0.433, acc 2=0.599, l1m1=1.509,l1m2=1.120


                                                                                               

validation epoch 12, acc 1 (doodle) = 0.385, acc 2 (real) = 0.508


                                                                                               

train epoch 13, acc 1=0.439, acc 2=0.597, l1m1=1.496,l1m2=1.116


                                                                                               

validation epoch 13, acc 1 (doodle) = 0.408, acc 2 (real) = 0.512


 73%|▋| 66/90 [00:35<00:11,  2.03it/s, acc 1=0.440607, acc 2=0.598455, l1m1=1.498335, l1m2=1.11