In [1]:
import sys
sys.path.append("../pytorch-transformers-implementation/")

In [2]:
from copy import deepcopy

In [41]:
from vision_transformer.dataset.iterator import ImageNetIterator
from torch.utils.data import DataLoader

In [42]:
train_iter = ImageNetIterator(root='../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

test_iter = ImageNetIterator(root='../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

train_loader = DataLoader(train_iter, batch_size=64*2, shuffle=True, num_workers=10)
valid_loader = DataLoader(test_iter, batch_size=64*2, shuffle=False, num_workers=10)

In [43]:
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self,
                 height,
                 width,
                 channel,
                 patch,
                 C):
        super().__init__()
        self.C = C

        self.patch_size = patch ** 2
        img_size = height * width
        assert img_size % self.patch_size == 0, 'img is not divisible with patch'

        self.seq_length = img_size // self.patch_size
        input_dim = self.patch_size * channel
        self.patch_emb = nn.Linear(input_dim, C)

    def forward(self, img):
        N, C, H, W = img.shape

        splitted = img.view(N, C, -1).split(self.patch_size, -1)  # [N, C, H*W]
        stacked_tensor = torch.stack(splitted, dim=2)  # [N, C, (H*W)/(P**2), P**2]

        stacked_tensor = stacked_tensor.permute(0, 2, 1, 3).contiguous()  # [N, (H*W)/(P**2), C, P**2]
        stacked_tensor = stacked_tensor.view(N, stacked_tensor.shape[1], -1)  # [N, (H*W)/(P**2), C * P**2]
        # S(sequence length) : (H*W)/(P**2)

        embeddings = self.patch_emb(stacked_tensor)  # [N, S, C]        
        return embeddings

In [44]:
for (x,y) in train_loader : 
    break

In [45]:
h, w, c = 256, 256, 3
p = 32
C=  256

pe = PatchEmbedding(h, w, c, p, C)
emb = pe(x)
S = emb.shape[1]

In [46]:
class Mixer(nn.Module) : 
    def __init__(self, input_dim, hidden_dim) :
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, input_dim)
        self.ln = nn.LayerNorm(input_dim)
        self.act = nn.GELU()
        
    def forward(self, x) : 
        normalized_x = self.ln(x)
        projected = self.w1(normalized_x)
        activated = self.act(projected)
        return x + self.w2(activated)
    
class MixerBlock(nn.Module) : 
    def __init__(self, channel_mixer, token_mixer) :
        super().__init__()
        self.cm = channel_mixer
        self.tm = token_mixer
        
    def forward(self, x) : 
        u = self.cm(x)
        ut = u.permute(0,2,1).contiguous()
        y = self.tm(ut)
        return y.permute(0,2,1).contiguous()
    
class MM(nn.Module) : 
    def __init__(self, 
                 height,
                 width,
                 channel,
                 num_layers,
                 patch_size,
                 C,
                 S,
                 D_c,
                 D_s,
                 output_dim) :
        super().__init__()
        
        self.pe = PatchEmbedding(height,
                            width,
                            channel,
                            patch_size,
                            C)
        
        channel_mixer = Mixer(C, D_c)
        token_mixer = Mixer(S, D_s)
        mixer = MixerBlock(channel_mixer, token_mixer)
        
        self.encoders = nn.ModuleList([deepcopy(mixer) for _ in range(num_layers)])
        self.fc = nn.Linear(C, output_dim)
        
    def forward(self, x) : 
        emb = self.pe(x)
        for enc in self.encoders : 
            emb = enc(emb)
            
        return self.fc(emb.mean(1))

In [55]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(MM(h, w, c, 6, p, C, S, 2048, 256, len(train_iter.label_dict)).to(device))

In [56]:
import torch.optim as optim
from tqdm.notebook import tqdm

In [57]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [58]:
def train() : 
    model.train()
    losses = []
    accuracies = []

    for x,y in tqdm(train_loader, desc='train') : 
        pred = model(x.to(device))

        optimizer.zero_grad()
        loss = criterion(pred, y.to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for x,y in tqdm(valid_loader, desc='valid') : 
        pred = model(x.to(device))

        loss = criterion(pred, y.to(device))
        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [None]:
epoches = 20

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 1th Epoch ===
    
        Train Loss : 3.89 | Train Acc : 0.143
        Valid Loss : 3.412 | Valid Acc : 0.211
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 2th Epoch ===
    
        Train Loss : 3.244 | Train Acc : 0.239
        Valid Loss : 2.925 | Valid Acc : 0.292
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 3th Epoch ===
    
        Train Loss : 2.966 | Train Acc : 0.289
        Valid Loss : 2.652 | Valid Acc : 0.346
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 4th Epoch ===
    
        Train Loss : 2.725 | Train Acc : 0.333
        Valid Loss : 2.35 | Valid Acc : 0.405
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]