In [1]:
import os
import numpy as np

from PIL import Image
import pickle

import torch
import torch.nn as nn

import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.optim as optim
from torchvision.datasets import cifar

!wget https://postbguacil-my.sharepoint.com/:f:/g/personal/guy5_post_bgu_ac_il/EjUkvMSgGsVAj2Lz6mVe7twBDRMflr-ADP1BMPJY8-eJYQ?e=6lDbkk
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
import warmup_scheduler
from autoaugment import CIFAR10Policy
torch.backends.cudnn.enabled = True

In [None]:
import random

random_seed = 1
# Turn cuDNN off entirely; this overrides the `enabled = True` set in the import
# cell. Presumably done to avoid nondeterministic cuDNN kernels — TODO confirm intent.
torch.backends.cudnn.enabled = False
# Seed every RNG the pipeline may touch: Python's `random` and NumPy (used by
# augmentation code such as AutoAugment policies), and torch (weight init,
# DataLoader shuffling, RandomCrop).
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [3]:
# --- ViT architecture ---
hidden = 384             # token width passed to ViT(hidden=...)
mlp_hidden = 384 * 3     # feed-forward width inside each encoder block (1152)
num_layers = 7           # number of encoder blocks
head = 12                # attention heads; `hidden` must be divisible by this

# --- data / batching ---
num_samples = 1024           # images memorised by the reconstruction path
batch_size_train_mem = 64    # memorisation (code -> image) batches
batch_size_train_cls = 128   # classification batches
batch_size_test = 128

In [4]:
# Short tag for the checkpoint filename: the exact count below 5000,
# otherwise the number of thousands with a "k" suffix.
if num_samples < 5000:
    max_train_samples = num_samples
else:
    max_train_samples = f'{num_samples//1000}k'

In [5]:
def cifar10(batch_num, max_samples, data_dir='./data/cifar-10-batches-py'):
    """Load up to `max_samples` images and labels from one CIFAR-10 python training batch.

    Args:
        batch_num: index of the training batch file (`data_batch_{batch_num}`).
        max_samples: maximum number of samples to return; fewer are returned when
            the batch file holds fewer rows.
        data_dir: directory containing the extracted python-format batch files.

    Returns:
        (samples, labels): `samples` is the raw pixel array reshaped to
        (n, 3, 32, 32); `labels` is the matching list slice of integer labels.
    """
    path = os.path.join(data_dir, f'data_batch_{batch_num}')
    # NOTE: pickle.load can execute arbitrary code -- only use with trusted CIFAR files.
    with open(path, 'rb') as f:
        batch = pickle.load(f, encoding="latin1")
    # -1 lets the reshape adapt when the batch holds fewer than max_samples rows.
    samples = batch['data'][:max_samples].reshape(-1, 3, 32, 32)
    labels = batch['labels'][:max_samples]
    return samples, labels

In [6]:
from collections import Counter

numclasses = 10           # number of CIFAR-10 classes
bitlength = numclasses    # width of each per-sample memorisation code vector

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def grayN(base, digits, value):
    """Return the `digits`-digit base-`base` Gray code of `value`.

    Digits are stored least-significant first (index 0 is the lowest digit).
    Only `digits` digits of `value` are kept, so values >= base**digits wrap
    around.  Consecutive values differ in exactly one digit.

    Args:
        base: radix of the code.
        digits: number of output digits; 0 yields an empty tensor.
        value: non-negative integer to encode.

    Returns:
        1-D tensor (torch default float dtype) of length `digits`, entries in [0, base).
    """
    # Plain base-N digits of `value`, least significant first.
    base_n = []
    for _ in range(digits):
        base_n.append(value % base)
        value = value // base

    # Walk from the most significant digit down: each Gray digit is shifted by the
    # running sum of the higher Gray digits (base - g keeps the shift non-negative).
    gray = [0] * digits
    shift = 0
    for i in reversed(range(digits)):
        gray[i] = (base_n[i] + shift) % base
        shift = shift + base - gray[i]
    return torch.tensor(gray, dtype=torch.get_default_dtype())

# NOTE(review): `E` is not referenced elsewhere in this notebook. Its weights are randomly
# initialised, so removing it would presumably shift the torch RNG stream used by
# later model initialisation -- left in place for reproducibility.
E = nn.Embedding(num_embeddings=numclasses, embedding_dim=numclasses)

max dataset size= 59049


In [7]:
class CustomCIFAR(Dataset):
    """Memorisation dataset: CIFAR-10 training images keyed by a per-sample code.

    Loads the first `max_samples` images of training batch 1 via `cifar10` and
    gives each sample a code vector of length `bitlength`:

    * the base-3 Gray code (`grayN(3, bitlength, count)`) of the sample's 1-based
      occurrence count within its class, plus
    * 3 added at the position of the sample's label.

    `input2index` maps ``(label, count)`` back to the dataset index. `grayN` keeps
    only `bitlength` base-3 digits, so per-class counts beyond ``3 ** bitlength``
    would wrap around and collide.

    Args:
        transform: stored on the instance but NOT applied in `__getitem__`.
        max_samples: number of samples to take from the batch file.
    """

    def __init__(self, transform=None,
                 max_samples=1024):
        self.transform = transform
        # loading
        (train_X, train_y) = cifar10(1, max_samples)

        self.data = train_X                 # raw pixels, (n, 3, 32, 32) from cifar10
        self.targets = np.array(train_y)

        # Per-class running counts, keyed by str(label).
        self.C = Counter()
        # NOTE(review): not used by this class. Its random initialisation presumably
        # advances the torch RNG, so it is kept to leave later weight init unchanged.
        self.class_embedding = nn.Embedding(numclasses, numclasses)
        self.class_embedding.requires_grad_(False)
        self.cbinIndexes = np.zeros((len(self.targets), bitlength))
        self.inputs = []
        self.input2index = {}

        with torch.no_grad():
            for i in range(len(self.data)):
                label = int(self.targets[i])
                key = str(label)
                # Increment this label's count directly; Counter.update(str(label))
                # would count each *character* and break for labels >= 10.
                self.C[key] += 1
                count = self.C[key]

                class_code = torch.zeros(numclasses)
                class_code[label] = 3
                self.cbinIndexes[i] = grayN(3, bitlength, count) + class_code

                code = torch.tensor(self.cbinIndexes[i]).float()
                self.inputs.append(code)
                self.input2index[(label, count)] = i

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index: int):
        """Return (code, one-hot label, image / 255), all moved to `device`.

        NOTE(review): tensors are moved to `device` inside the dataset, so with CUDA
        this likely requires the DataLoader to use the default num_workers=0.
        """
        img, target = self.data[index], int(self.targets[index])
        img = torch.from_numpy(img) / 255

        label = torch.zeros(numclasses).float()
        label[target] = 1
        return self.inputs[index].to(device), label.to(device), img.to(device)


In [8]:
# Memorisation loader: (code, one-hot label, image) triples from CustomCIFAR.
# NOTE(review): CustomCIFAR stores `transform` but never applies it.
mem_dataset = CustomCIFAR(
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
    max_samples=num_samples,
)
train_loader_mem = torch.utils.data.DataLoader(
    mem_dataset, batch_size=batch_size_train_mem, shuffle=True)

# Classification loader: full CIFAR-10 training split with random crop + AutoAugment.
cls_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(size=32, padding=3),
    CIFAR10Policy(),
    torchvision.transforms.ToTensor(),
])
cls_dataset = torchvision.datasets.cifar.CIFAR10(
    root='./data', train=True, transform=cls_transform)
train_loader_cls = torch.utils.data.DataLoader(
    cls_dataset, batch_size=batch_size_train_cls, shuffle=True, pin_memory=True)

# Test loader: CIFAR-10 test split, tensor conversion only.
test_dataset = cifar.CIFAR10(
    root='./data', train=False,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test)

In [9]:
import torch
import torch.nn as nn


class TransformerEncoder(nn.Module):
    """Pre-norm Transformer block: h = x + MSA(LN1(x)); out = h + MLP(LN2(h)).

    Args:
        feats: token feature width.
        mlp_hidden: hidden width of the feed-forward sub-layer.
        head: number of attention heads (must divide `feats`).
        dropout: dropout probability for the attention output and the MLP.
    """

    def __init__(self, feats:int, mlp_hidden:int, head:int=8, dropout:float=0.):
        super().__init__()
        # Attribute names and submodule order are kept stable: they fix the
        # state_dict keys and the order in which weights are randomly initialised.
        self.la1 = nn.LayerNorm(feats)
        self.msa = MultiHeadSelfAttention(feats, head=head, dropout=dropout)
        self.la2 = nn.LayerNorm(feats)
        # Note: a GELU also follows the output projection of the MLP.
        ff_layers = [
            nn.Linear(feats, mlp_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_hidden, feats), nn.GELU(), nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*ff_layers)

    def forward(self, x):
        attended = x + self.msa(self.la1(x))
        return attended + self.mlp(self.la2(attended))


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over (batch, tokens, feats) inputs.

    Args:
        feats: token feature width; must be divisible by `head`.
        head: number of attention heads.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, feats:int, head:int=8, dropout:float=0.):
        super().__init__()
        self.head, self.feats = head, feats
        # NOTE(review): scores are divided by sqrt(feats), not sqrt(feats // head) as in
        # standard scaled dot-product attention; kept because trained weights depend on it.
        self.sqrt_d = feats ** 0.5

        # Creation order (q, k, v, o) fixes the random-initialisation order.
        self.q = nn.Linear(feats, feats)
        self.k = nn.Linear(feats, feats)
        self.v = nn.Linear(feats, feats)

        self.o = nn.Linear(feats, feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, tokens, _ = x.size()
        head_dim = self.feats // self.head

        def split_heads(t):
            # (b, n, f) -> (b, h, n, f//h)
            return t.view(batch, tokens, self.head, head_dim).transpose(1, 2)

        query = split_heads(self.q(x))
        key = split_heads(self.k(x))
        value = split_heads(self.v(x))

        logits = torch.einsum("bhif, bhjf->bhij", query, key) / self.sqrt_d
        weights = F.softmax(logits, dim=-1)                                  # (b, h, n, n)
        mixed = torch.einsum("bhij, bhjf->bihf", weights, value)            # (b, n, h, f//h)
        return self.dropout(self.o(mixed.flatten(2)))
    

class ViT(nn.Module):
    """Patch-based Vision Transformer classifier with a weight-sharing decoding path.

    `forward`: images (b, in_c, img_size, img_size) -> patch tokens -> linear
    embedding + learned positional embedding -> encoder stack -> mean over tokens
    -> LayerNorm + Linear head -> logits (b, num_classes).

    `forward_transposed`: a (b, num_classes) code -> image, by multiplying with the
    (untransposed) weights of the head and of the patch embedding, reusing the same
    parameters as `forward`.

    Args:
        in_c: number of image channels.
        num_classes: number of output logits.
        img_size: side length of the square input images (divisible by `patch`).
        patch: number of patches per row (and per column), not the patch side length.
        dropout, num_layers, hidden, mlp_hidden, head: encoder hyper-parameters.
    """

    def __init__(self, in_c:int=3, num_classes:int=10, img_size:int=32, patch:int=8,
                 dropout:float=0., num_layers:int=7, hidden:int=416,
                 mlp_hidden:int=416*4, head:int=8):
        super(ViT, self).__init__()

        self.hidden = hidden
        self.patch = patch  # number of patches in one row (or col)
        self.patch_size = img_size // self.patch  # side length of one patch in pixels
        f = self.patch_size**2 * in_c  # flattened patch vector length (48 for the defaults)
        self.num_tokens = self.patch**2

        self.emb = nn.Linear(f, hidden)  # (b, n, f) -> (b, n, hidden)
        self.pos_emb = nn.Parameter(torch.randn(1, self.num_tokens, hidden))
        enc_list = [TransformerEncoder(hidden,
                                       mlp_hidden=mlp_hidden,
                                       dropout=dropout,
                                       head=head) for _ in range(num_layers)]

        # NOTE(review): despite the name, `enc_list[-1::]` is just [last layer], not a
        # reversed stack, so the decoding path shares only the final encoder layer.
        # Kept as-is: it determines the `enc_reversed.*` keys of existing checkpoints.
        enc_list_reversed = enc_list[-1::]

        self.enc = nn.Sequential(*enc_list)
        self.enc_reversed = nn.Sequential(*enc_list_reversed)

        self.fc = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Classify images (b, in_c, H, W) into logits (b, num_classes)."""
        out = self._to_words(x)      # (b, n, f)
        out = self.emb(out)          # (b, n, hidden)
        out = out + self.pos_emb
        out = self.enc(out)
        out = out.mean(1)            # average over tokens (no cls token)
        out = self.fc(out)
        return out

    def forward_transposed(self, code):
        """Decode a (b, num_classes) code into an image (b, in_c, H, W) with shared weights."""
        code = torch.matmul(code, self.fc[1].weight)  # (b, C) @ (C, hidden) -> (b, hidden)
        code = self.fc[0](code)
        # (b, 1, hidden) + (1, n, hidden): broadcast the code to every token position.
        code = code.reshape(code.size(0), 1, self.hidden) + self.pos_emb

        code = self.enc_reversed(code)
        code = torch.matmul(code, self.emb.weight)     # (b, n, hidden) @ (hidden, f) -> (b, n, f)
        img = self._from_words(code)
        return img

    def _to_words(self, x):
        """
        (b, c, h, w) -> (b, n, f)

        Patches are taken row-major over the patch grid; each patch vector is
        flattened in (ph, pw, c) order (channel fastest).
        """
        out = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size).permute(0,2,3,4,5,1)
        out = out.reshape(x.size(0), self.patch**2 ,-1)
        return out

    def _from_words(self, x):
        """
        (b, n, f) -> (b, c, h, w)

        The channel count is inferred from f = c * patch_size**2. Each patch vector
        is read channel-major as (c, ph, pw) and patches are laid out row-major.
        NOTE(review): `_to_words` flattens patches as (ph, pw, c), so this is not its
        exact inverse; the layout is kept so existing checkpoints decode identically.
        """
        x = x.reshape(x.size(0), self.patch**2, -1, self.patch_size, self.patch_size)
        b, p, c, ph, pw = x.shape
        sh = sw = self.patch  # patch grid is patch x patch
        x = x.view(b, sh, sw, c, ph, pw)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        x = x.view(b, c, sh * ph, sw * pw)
        return x
    

In [10]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Flattened image size and number of logits (not used by the ViT call below).
input_size = 3 * 32 * 32
output_size = numclasses

# Build the ViT with the configured hyper-parameters and move it to `device`.
model = ViT(
    hidden=hidden,
    mlp_hidden=mlp_hidden,
    num_layers=num_layers,
    head=head,
).to(device)

In [12]:
class LabelSmoothingCrossEntropyLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets probability ``1 - smoothing``; the remaining ``smoothing``
    mass is spread uniformly over the other ``classes - 1`` classes.

    Args:
        classes: number of classes (size of `pred` along `dim`).
        smoothing: probability mass moved off the true class.
        dim: dimension of `pred` that holds the class scores.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingCrossEntropyLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: unnormalised logits with class scores along `self.dim`.
            target: integer class indices shaped like `pred` without the class dim.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # With a single class there is nowhere to spread the smoothing mass.
            off_value = self.smoothing / (self.cls - 1) if self.cls > 1 else 0.0
            true_dist = torch.full_like(log_probs, off_value)
            # Scatter along the class dim itself so N-D inputs with dim=-1 work too.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))

In [13]:
# Losses: label-smoothed cross-entropy for classification, MSE for image reconstruction.
CE = LabelSmoothingCrossEntropyLoss(classes=10, smoothing=0.2)
MSE = nn.MSELoss()

iterations = 10000      # number of training epochs (also the cosine-annealing period)
best_loss_r = np.inf    # lowest epoch-mean reconstruction loss seen so far


def _adamw_with_warmup_cosine(lr, eta_min):
    """AdamW over all model parameters, wrapped in GradualWarmupScheduler
    (total_epoch=5) that hands off to CosineAnnealingLR decaying to `eta_min`.

    Returns (optimizer, cosine scheduler, warmup scheduler).
    """
    opt = optim.AdamW(model.parameters(), lr=lr,)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=iterations, eta_min=eta_min)
    warmup = warmup_scheduler.GradualWarmupScheduler(opt, multiplier=1., total_epoch=5,
                                                     after_scheduler=cosine)
    return opt, cosine, warmup


# Two independent optimizers over the SAME parameters; each keeps its own Adam state.
optimizer_cls, lr_scheduler_cls, scheduler_cls = _adamw_with_warmup_cosine(1e-4, 1e-6)
optimizer_mem, lr_scheduler_mem, scheduler_mem = _adamw_with_warmup_cosine(1e-3, 1e-5)


In [14]:
# Checkpoint path encodes the run's hyper-parameters.
# ("forwar" is kept as-is so previously saved checkpoints are still found.)
save_path = f'./models/cifar10_vit_{max_train_samples}_{hidden}width_{mlp_hidden}mlp_dim_{num_layers}layers_split_forwar_regularized.pt'
# Create the target directory so torch.save in the training loop can write there
# on a fresh checkout.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_path

'./cifar_vit/cifar10_vit_1024_384width_1152mlp_dim_7layers_split_forwar_regularized.pt'

In [15]:
for epoch in range(iterations):
    # The evaluation cells call model.eval(); re-enable training mode explicitly so
    # re-running this cell after them behaves the same as a fresh run.
    model.train()
    loss_c = 0.0
    loss_r = 0.0
    c = 0
    # zip() stops at the shorter loader, so an epoch is min(len(train_loader_mem),
    # len(train_loader_cls)) paired steps (the memorisation loader yields 16 batches
    # with num_samples=1024 and batch_size_train_mem=64).
    for (code, _, imgs), (data, labels) in zip(train_loader_mem, train_loader_cls):
        data = data.to(device)
        code = code.to(device)
        imgs = imgs.to(device)
        labels = labels.to(device)

        # 1) Classification step (images -> labels), updated by optimizer_cls.
        optimizer_cls.zero_grad()
        optimizer_mem.zero_grad()
        predlabel = model(data)
        loss_classf = CE(predlabel, labels)
        loss_classf.backward()
        optimizer_cls.step()

        # 2) Reconstruction step (code -> image via the transposed path), updated by optimizer_mem.
        optimizer_mem.zero_grad()
        optimizer_cls.zero_grad()
        predimg = model.forward_transposed(code)
        loss_recon = MSE(predimg, imgs)
        loss_recon.backward()
        optimizer_mem.step()

        # Accumulate mini-batch losses for the epoch means.
        loss_c += loss_classf.item()
        loss_r += loss_recon.item()
        c += 1

    if c == 0:
        raise RuntimeError("training loaders yielded no batches; cannot compute epoch losses")

    # Per-epoch learning-rate schedules.
    scheduler_cls.step()
    scheduler_mem.step()

    epoch_loss_c = loss_c / c
    epoch_loss_r = loss_r / c
    print("Iteration : {}/{}, loss_c = {:.6f}, loss_r = {:.6f}".format(epoch + 1, iterations, epoch_loss_c, epoch_loss_r))
    # Keep only the checkpoint with the lowest epoch-mean reconstruction loss.
    if epoch_loss_r < best_loss_r:
        model_state = {'net': model.state_dict(),
                       'opti_mem': optimizer_mem.state_dict(),
                       'opti_cls': optimizer_cls.state_dict(),
                       'loss_r': epoch_loss_r}
        torch.save(model_state, save_path)
        best_loss_r = epoch_loss_r

epoch : 580/10000, loss_c = 2.065230, loss_r = 0.039051
epoch : 581/10000, loss_c = 2.070234, loss_r = 0.040514
epoch : 582/10000, loss_c = 2.063073, loss_r = 0.039924
epoch : 583/10000, loss_c = 2.056525, loss_r = 0.040125
epoch : 584/10000, loss_c = 2.058969, loss_r = 0.039740
epoch : 585/10000, loss_c = 2.067906, loss_r = 0.039862
epoch : 586/10000, loss_c = 2.060395, loss_r = 0.039397
epoch : 587/10000, loss_c = 2.066046, loss_r = 0.038535
epoch : 588/10000, loss_c = 2.062440, loss_r = 0.038421
epoch : 589/10000, loss_c = 2.059731, loss_r = 0.037518
epoch : 590/10000, loss_c = 2.062116, loss_r = 0.037658
epoch : 591/10000, loss_c = 2.060669, loss_r = 0.037390
epoch : 592/10000, loss_c = 2.062395, loss_r = 0.037605
epoch : 593/10000, loss_c = 2.071350, loss_r = 0.037350
epoch : 594/10000, loss_c = 2.078034, loss_r = 0.038064
epoch : 595/10000, loss_c = 2.056024, loss_r = 0.038951
epoch : 596/10000, loss_c = 2.051270, loss_r = 0.038417
epoch : 597/10000, loss_c = 2.054557, loss_r = 0

KeyboardInterrupt: 

In [16]:
# Restore the best checkpoint written by the training loop (read from disk once;
# map_location lets a GPU-saved checkpoint load on a CPU-only machine).
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['net'])
checkpoint['loss_r']

0.002482061589641186

In [18]:
# Top-1 accuracy of the classification path on the CIFAR-10 test split.
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        ypred = model(inputs).argmax(dim=1)
        correct += (ypred == labels).sum().item()
        total += labels.size(0)
print("Acc", correct / total if total else float("nan"))

Acc tensor(0.8108, device='cuda:0')


In [22]:
# Decode every memorised code and measure per-image reconstruction error.
# train_loader_mem shuffles, so rows follow loader order, not dataset index order;
# errors, reconstructions, originals and labels stay aligned row by row.
model.eval()
errors, recons, originals, labels_seen = [], [], [], []

with torch.no_grad():
    for codes, labels, imgs in train_loader_mem:
        codes = codes.to(device)
        imgs = imgs.to(device)
        imgrecon = model.forward_transposed(codes)
        # Mean squared error per image over all channels and pixels.
        error = ((imgs - imgrecon) ** 2).mean(dim=(1, 2, 3))
        errors.append(error.cpu().numpy())
        recons.append(imgrecon.cpu())
        originals.append(imgs.cpu())
        labels_seen.append(labels.cpu().numpy())

error_list = np.concatenate(errors)
recon_list = torch.cat(recons, dim=0)
org_list = torch.cat(originals, dim=0)
label_list = np.concatenate(labels_seen)  # one-hot labels, as returned by CustomCIFAR