In [1]:
import os
import numpy as np
import logging
import pickle

import torch
import torch.nn as nn

import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torchvision.datasets import cifar

!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
import warmup_scheduler
from autoaugment import CIFAR10Policy
torch.backends.cudnn.enabled = True

Collecting git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
  Cloning https://github.com/ildoonet/pytorch-gradual-warmup-lr.git to /tmp/pip-req-build-a5mi8c3f
  Running command git clone -q https://github.com/ildoonet/pytorch-gradual-warmup-lr.git /tmp/pip-req-build-a5mi8c3f
  Resolved https://github.com/ildoonet/pytorch-gradual-warmup-lr.git to commit 6b5e8953a80aef5b324104dc0c2e9b8c34d622bd
  Preparing metadata (setup.py) ... [?25ldone
[?25h

In [2]:
# Fixed seed for torch's global RNG so that runs are repeatable.
random_seed = 1
# Turns cuDNN off; this overrides the `True` assigned in the import cell above.
torch.backends.cudnn.enabled = False
# Seeds torch's default generator (manual_seed returns the Generator, which
# is why the cell displays one).
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f95fd70f150>

In [3]:
# Model size: forwarded to CNN(n_layers=..., n_channels=...) when the model
# is instantiated.
n_channels= 384
n_layers = 3

# Number of CIFAR-10 images handed to CustomCIFAR as `max_samples`.
num_samples = 1024
# Batch sizes of the three DataLoaders (train_loader_mem, train_loader_cls
# and test_loader_mem respectively).
batch_size_train_mem = 64
batch_size_train_cls = 128
batch_size_test = 128

In [4]:
# Tag embedded in the checkpoint file name: the plain integer below 5000
# samples, otherwise a string such as '5k' (integer division by 1000, so the
# remainder is dropped).
max_train_samples = num_samples if num_samples<5000 else f'{num_samples//1000}k'

In [5]:
def cifar10(batch_num, max_samples):
    """Load the first `max_samples` images of one CIFAR-10 training batch.

    Args:
        batch_num: Suffix of the batch file to read
            (`./data/cifar-10-batches-py/data_batch_{batch_num}`).
        max_samples: Upper bound on the number of samples returned. If the
            batch holds fewer rows (or `None` is passed) every available row
            is returned.

    Returns:
        Tuple `(samples, labels)`: `samples` is the batch's `'data'` entry
        sliced and reshaped to `(n, 3, 32, 32)`; `labels` is the matching
        slice of the batch's `'labels'` entry.
    """
    # Instantiated only for its side effect: makes sure the archive is
    # downloaded and unpacked under ./data before the raw batch file is read.
    torchvision.datasets.cifar.CIFAR10(
        root='./data', train=True, download=True)
    # NOTE(review): pickle.load can execute arbitrary code. This is only
    # acceptable because the file is the one fetched by torchvision above;
    # never point this at an untrusted path.
    with open(f'./data/cifar-10-batches-py/data_batch_{batch_num}',
              'rb') as f:
        batch = pickle.load(f, encoding="latin1")
    # Let reshape infer the leading dimension: the slice may contain fewer
    # than `max_samples` rows, in which case a fixed size would raise.
    samples = batch['data'][:max_samples].reshape(-1, 3, 32, 32)
    labels = batch['labels'][:max_samples]
    return samples, labels

In [6]:
from collections import Counter

# Number of classes; sizes the one-hot labels and the class part of the codes.
numclasses = 10
# Width of the per-sample code vectors built in CustomCIFAR (kept equal to
# the number of classes).
bitlength = numclasses 
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def grayN(base, digits, value):
    """Return the `digits`-digit base-`base` Gray code of `value`.

    The value is first written in base `base` (least significant digit at
    index 0). The digits are then walked from the most significant one down,
    each being offset by a running shift so that consecutive values differ in
    exactly one digit.

    Args:
        base: Radix of the code.
        digits: Number of digits to produce. `0` yields an empty tensor.
        value: Non-negative integer to encode; digits beyond `digits` are
            discarded.

    Returns:
        1-D tensor of length `digits` in torch's default floating dtype, with
        the least significant digit at index 0.
    """
    # Plain base-N expansion, least significant digit first.
    base_digits = []
    for _ in range(digits):
        value, digit = divmod(value, base)
        base_digits.append(digit)

    # Work on Python numbers and build the tensor once at the end; iterating
    # over explicit positions also avoids depending on a leaked loop variable
    # (which is undefined when digits == 0).
    gray = [0] * digits
    shift = 0
    for pos in reversed(range(digits)):
        gray[pos] = (base_digits[pos] + shift) % base
        shift += base - gray[pos]
    return torch.tensor(gray, dtype=torch.get_default_dtype())


In [7]:
class CustomCIFAR(Dataset):
    """CIFAR-10 subset in which every image is paired with a unique code.

    For sample `i` with class `label`, the code is the base-3 Gray code of
    the running count of that class (1 for its first occurrence), plus 3 at
    position `label`. Gray digits lie in {0, 1, 2}, so the entry at the class
    position is the only one that can reach 3 or more.

    Attributes set by `__init__`:
        data: image array returned by `cifar10(1, max_samples)`.
        targets: matching labels.
        C: Counter keyed by `str(label)` holding per-class occurrence counts.
        cbinIndexes: `(len(targets), bitlength)` numpy array of all codes.
        inputs: list of the codes as float tensors, indexed like `data`.
        input2index: reverse lookup `(label, occurrence) -> sample index`.
    """

    def __init__(self, transform=None,
                 max_samples=1024):
        """Load batch 1 of CIFAR-10 and build the per-sample codes.

        Args:
            transform: Stored on the instance. NOTE(review): it is never
                applied in `__getitem__`; images are only scaled by 1/255.
            max_samples: Number of samples requested from `cifar10`.
        """
        self.transform = transform
        # loading
        (train_X, train_y) = cifar10(1, max_samples)

        self.data = train_X
        self.targets = train_y
        # create index+class embeddings, and a reverse lookup
        self.C = Counter()
        self.cbinIndexes = np.zeros((len(self.targets), bitlength))
        self.inputs = []
        self.input2index = {}

        for i in range(len(self.data)):
            label = int(self.targets[i])
            key = str(label)
            # Increment by key. Counter.update(str) would iterate over the
            # characters of the string and miscount multi-digit labels.
            self.C[key] += 1
            occurrence = self.C[key]

            class_code = torch.zeros(numclasses)
            class_code[label] = 3
            code = grayN(3, bitlength, occurrence) + class_code
            self.cbinIndexes[i] = code.numpy()

            self.inputs.append(torch.tensor(self.cbinIndexes[i]).float())
            self.input2index[(label, occurrence)] = i

    def __len__(self):
        """Number of samples in the subset."""
        return len(self.targets)

    def __getitem__(self, index: int):
        """Return `(code, one_hot_label, image)` for one sample.

        All three tensors are moved to the module-level `device`; the image
        is the stored array divided by 255.
        """
        img, target = self.data[index], int(self.targets[index])
        img = torch.from_numpy(img) / 255

        label = torch.zeros(numclasses).float()
        label[target] = 1
        return self.inputs[index].to(device), label.to(device), img.to(device)


In [8]:
# Loader for the code -> image reconstruction task: yields
# (code, one-hot label, image) batches from CustomCIFAR.
# NOTE(review): the `transform` passed here is stored by CustomCIFAR but not
# applied in its __getitem__.
train_loader_mem = torch.utils.data.DataLoader(
  CustomCIFAR(transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ]), max_samples=num_samples),
  batch_size=batch_size_train_mem, shuffle=True)

# Loader for the classification task: CIFAR-10 training split with random
# crop + CIFAR10Policy augmentation (from the local `autoaugment` module).
# No `download=True` here; the files are fetched by the `cifar10` helper that
# CustomCIFAR calls above.
train_loader_cls = torch.utils.data.DataLoader(
    torchvision.datasets.cifar.CIFAR10(
        root='./data', train=True,
        transform= torchvision.transforms.Compose([
                                      torchvision.transforms.RandomCrop(size=32, padding=3),
                                      CIFAR10Policy(),
                                      torchvision.transforms.ToTensor()])),
  batch_size=batch_size_train_cls, shuffle=True, pin_memory=True)


# CIFAR-10 test split (train=False), un-augmented, used for the accuracy
# evaluation; yields (image, label) pairs despite the `_mem` suffix.
test_loader_mem = torch.utils.data.DataLoader(cifar.CIFAR10(
    root='./data', train=False, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ])), batch_size=batch_size_test)

Files already downloaded and verified


In [9]:
class Conv_Layer(nn.Module):
    """Unpadded 3x3 convolution followed by ReLU, with a transposed pass.

    The transposed pass reuses the kernel of `self.conv`, so one set of
    weights serves both directions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=(3, 3), stride=1, padding=0,
                              bias=True)

    def forward(self, x):
        """Apply the convolution and a ReLU; each spatial side shrinks by 2."""
        return F.relu(self.conv(x))

    def forward_transposed(self, code):
        """Apply the transposed convolution with the same kernel, then ReLU.

        The kernel is read through `.data`, i.e. detached from autograd, so
        no gradient reaches `conv.weight` via this path. The convolution bias
        is not used here.
        """
        kernel = self.conv.weight.data
        upsampled = F.conv_transpose2d(code, kernel, padding=0)
        return F.relu(upsampled)
    
class CNN(nn.Module):
    """Conv stack + two linear layers, usable in both directions.

    `forward` maps images to 10 logits. `forward_transposed` maps a
    10-dimensional code back to an image by applying the same weights in
    reverse order (linear weights via matmul, conv kernels via transposed
    convolution). Biases are not used on the transposed path.

    Args:
        n_layers: Number of `Conv_Layer` blocks (at least one is always
            built).
        n_channels: Width of every conv block and of the hidden linear layer.
        image_size: Side length of the square input images. Defaults to 32,
            which with `n_layers=3` gives the 13x13 pooled map that used to
            be hard-coded.
    """

    def __init__(self, n_layers, n_channels, image_size=32):
        super(CNN, self).__init__()
        self.n_channels = n_channels
        self.conv_layers = [Conv_Layer(3, n_channels)] + [
            Conv_Layer(n_channels, n_channels)
            for _ in range(n_layers - 1)]
        # The Sequential registers the layers' parameters; the plain list is
        # kept so forward_transposed can walk the layers in reverse.
        self.conv_layers_forward = nn.Sequential(*self.conv_layers)

        self.avg_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)

        # Every unpadded 3x3 conv removes 2 pixels per side; the 2x2 pooling
        # then halves the result. Deriving the size keeps linear1 consistent
        # with n_layers instead of assuming exactly three layers.
        conv_out = image_size - 2 * len(self.conv_layers)
        self.pooled_size = conv_out // 2
        if self.pooled_size < 1:
            raise ValueError(
                f'image_size={image_size} is too small for '
                f'{len(self.conv_layers)} unpadded 3x3 conv layers')

        self.linear1 = nn.Linear(n_channels * self.pooled_size ** 2,
                                 n_channels, bias=True)
        self.linear2 = nn.Linear(n_channels, 10, bias=True)

    def forward(self, x):
        """Classify a batch of images; returns `(batch, 10)` logits."""
        x = self.conv_layers_forward(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

    def forward_transposed(self, code):
        """Decode a batch of `(batch, 10)` codes into images.

        Steps: multiply by `linear2.weight`, ReLU, multiply by
        `linear1.weight`, reshape to the pooled feature map, upsample by 2
        (undoing the pooling's size change), then run every conv block's
        transposed pass from last to first.

        NOTE: if the conv output side is odd, pooling floors it and the
        reconstruction comes out one pixel smaller than the input.
        """
        code = torch.matmul(code, self.linear2.weight)
        code = torch.relu(code)
        code = torch.matmul(code, self.linear1.weight)
        code = code.view(code.size(0), self.n_channels,
                         self.pooled_size, self.pooled_size)
        code = F.interpolate(code, scale_factor=2,
                             recompute_scale_factor=False)
        for layer in self.conv_layers[::-1]:
            code = layer.forward_transposed(code)
        return code
    

In [10]:
# Same device selection as in the earlier cell; repeated so this cell can be
# re-run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Flattened image size and number of outputs.
# NOTE(review): neither value is passed to CNN below — confirm whether they
# are still needed.
input_size =  int(3*32*32)
output_size =  int(numclasses)

# Build the network from the hyper-parameter cell and move it to `device`.
model = CNN(n_layers = n_layers,
            n_channels=n_channels).to(device)

In [11]:
class LabelSmoothingCrossEntropyLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability `1 - smoothing`; the remaining
    `smoothing` mass is spread evenly over the other `classes - 1` classes.

    Args:
        classes: Number of classes (size of `pred` along `dim`).
        smoothing: Probability mass moved away from the true class.
        dim: Class dimension of `pred`.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingCrossEntropyLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: Raw logits with the class scores along `self.dim`.
            target: Integer class indices, shaped like `pred` without the
                class dimension.
        """
        pred = pred.log_softmax(dim=self.dim)
        # The target distribution is a constant; keep it out of the graph.
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # Scatter along the configured class dimension (not a fixed
            # dim 1), so inputs with extra dimensions are handled the same
            # way as the log-softmax and the sum.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim),
                               self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

In [12]:
# Adjust the number of training iterations and optimization settings to your liking.

# Classification loss (label-smoothed cross-entropy) and reconstruction loss.
CE = LabelSmoothingCrossEntropyLoss(classes=10, smoothing=0.2)
MSE = nn.MSELoss()
# Number of passes of the training loop; also the cosine schedule length.
iterations = 1000
# Best mean reconstruction loss seen so far; the training loop checkpoints
# whenever an epoch improves on it.
best_loss_r = np.inf

# Optimizer stepped after the classification loss. Note that it holds the
# same parameters as `optimizer_mem` below, each with its own Adam state.
optimizer_cls = optim.Adam(model.parameters(), lr=1e-4,)

# Cosine decay towards eta_min over `iterations` scheduler steps, wrapped in
# the third-party GradualWarmupScheduler.
# NOTE(review): with multiplier=1. the warm-up presumably ramps the learning
# rate up to the base lr over the first 5 steps — confirm against the
# warmup_scheduler package.
lr_scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cls,
                                                          T_max=iterations, 
                                                              eta_min=1e-6)
scheduler_cls = warmup_scheduler.GradualWarmupScheduler(optimizer_cls, multiplier=1.,
                                                    total_epoch=5, after_scheduler=lr_scheduler_cls)


# Optimizer stepped after the reconstruction loss; same settings and
# schedule as the classification optimizer.
optimizer_mem = optim.Adam(model.parameters(), lr=1e-4,)
lr_scheduler_mem = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_mem,
                                                          T_max=iterations,
                                                              eta_min=1e-6)
scheduler_mem = warmup_scheduler.GradualWarmupScheduler(optimizer_mem, multiplier=1.,
                                                    total_epoch=5, after_scheduler=lr_scheduler_mem)


In [13]:
# Checkpoint location; the file name encodes the sample count, channel width
# and layer count. The training log is written next to it as
# f'{save_path}.log'.
save_path = f'./models/cifar10_3gray_cnn_{max_train_samples}_{n_channels}channels_{n_layers}layers_split_forward.pt'
# Bare expression so the notebook displays the resolved path.
save_path

'./models/cifar10_3gray_cnn_1024_384channels_3layers_split_forward.pt'

In [None]:
# The checkpoint and the log live in the same directory; create it up front
# so neither logging nor torch.save fails on a fresh checkout.
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

# logging.basicConfig does nothing once the root logger has handlers, which
# is the case whenever this cell is re-run in the same kernel. Drop (and
# close) the old handlers, then open the log in 'w' mode so every run starts
# with an empty file.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    handler.close()
logging.basicConfig(filename=f'{save_path}.log', filemode='w',
                    level=logging.INFO)
logging.info('Start Training')

# Make sure the model is in training mode even if an evaluation cell was
# executed before this one.
model.train()

for epoch in range(iterations):
    loss_c = 0
    loss_r = 0
    c = 0
    # zip stops at the shorter loader, so each epoch pairs every batch of
    # the code->image loader with one batch of the classification loader.
    for (code, _, imgs), (data, labels) in zip(train_loader_mem,
                                               train_loader_cls):
        data = data.to(device)
        code = code.to(device)
        imgs = imgs.to(device)
        labels = labels.to(device)

        # --- classification step: images -> logits -------------------
        # Both optimizers wrap the same parameters, so gradients are
        # cleared before each of the two backward passes.
        optimizer_cls.zero_grad()
        optimizer_mem.zero_grad()
        predlabel = model(data)
        loss_classf = CE(predlabel,
                         labels)
        loss_classf.backward()
        optimizer_cls.step()

        # --- reconstruction step: codes -> images --------------------
        optimizer_mem.zero_grad()
        optimizer_cls.zero_grad()
        predimg = model.forward_transposed(code)
        loss_recon = MSE(predimg, imgs)
        loss_recon.backward()
        optimizer_mem.step()

        loss_c += loss_classf.item()
        loss_r += loss_recon.item()
        c += 1

    if c == 0:
        raise RuntimeError('No training batches were produced; '
                           'check train_loader_mem and train_loader_cls.')

    scheduler_cls.step()
    scheduler_mem.step()

    mean_loss_c = loss_c / c
    mean_loss_r = loss_r / c
    message = "Iteration : {}/{}, loss_c = {:.6f}, loss_r = {:.6f}".format(
        epoch + 1, iterations, mean_loss_c, mean_loss_r)
    print(message)
    logging.info(message)

    # Keep only the checkpoint with the lowest mean reconstruction loss.
    if mean_loss_r < best_loss_r:
        model_state = {'net': model.state_dict(),
                       'opti_mem': optimizer_mem.state_dict(),
                       'opti_cls': optimizer_cls.state_dict(),
                       'loss_r': mean_loss_r}
        torch.save(model_state, save_path)
        best_loss_r = mean_loss_r

In [None]:
# Read the checkpoint once (it also contains both optimizer states) and map
# its tensors onto the current device, so a checkpoint written on a GPU can
# be restored on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['net'])
# Display the reconstruction loss that was stored with the checkpoint.
checkpoint['loss_r']

In [None]:
# Top-1 accuracy of the classification direction on the CIFAR-10 test split.
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for (inputs, labels) in test_loader_mem:
        inputs = inputs.to(device)
        labels = labels.to(device)
        output = model(inputs)
        # Index of the largest logit per sample.
        ypred = output.max(dim=1)[1]
        # .item() keeps the counters as Python ints, so the printed accuracy
        # is a plain float rather than a device tensor repr.
        correct += ypred.eq(labels).sum().item()
        total += ypred.size(0)
if total == 0:
    raise RuntimeError('test_loader_mem produced no batches.')
print("Acc", correct/total)

In [None]:
# Reconstruct every memorised training image from its code and collect the
# per-sample mean squared error alongside the reconstructions, originals and
# labels.
model.eval()
error_list = []
recon_list = []
org_list = []
label_list = []

with torch.no_grad():
    # train_loader_mem shuffles, so the sample order differs between runs;
    # the four lists nevertheless stay aligned with each other.
    for codes, labels, imgs in train_loader_mem:
        imgs = imgs.to(device)
        imgrecon = model.forward_transposed(codes)
        # Squared error summed over channel/height/width and divided by the
        # number of values per image (3*32*32), i.e. a per-image MSE.
        error = ((imgs - imgrecon)**2).sum(dim=(1,2,3))/(3*32*32)
        error_list.append(error.cpu().numpy())
        recon_list.append(imgrecon.cpu())
        org_list.append(imgs.cpu())
        # `labels` are the one-hot label tensors yielded by CustomCIFAR.
        label_list.append(labels.cpu().numpy())
# Merge the per-batch pieces; each name is rebound from a list to a single
# array/tensor, so this cell must be re-run from its first line.
error_list = np.concatenate(error_list)
recon_list = torch.cat(recon_list, axis=0)
org_list = torch.cat(org_list, axis=0)
label_list = np.concatenate(label_list)