In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
import os 
import numpy as np
import cv2
from tqdm import tqdm

In [12]:
params = {'input_channels':3, 'dim':768, 'hidden_dim':3072, 'patch_size':16, 'img_size':224, 'num_layers':12, 
          'dropout':0.0, 'attention_dropout':0.0, 'num_heads':12, 'fine_tune':None, 'num_classes':10, 'encoder_norm':True, 'fc_norm':False}

In [13]:
from vit.network import *
from vit.model import *

In [14]:
model = VisionTransformer(**params)

In [15]:
import torchvision
from torchvision import transforms

In [16]:
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

In [17]:
data = torchvision.datasets.CIFAR10('D:/Data/random/', transform=transform_train, train=True, download=False)

In [18]:
train, val = torch.utils.data.random_split(data, [40000, 10000])
trainloader = torch.utils.data.DataLoader(train, shuffle=True, batch_size=16, num_workers=8)
valloader = torch.utils.data.DataLoader(val, shuffle=True, batch_size=16, num_workers=8)

In [19]:
device = torch.device('cuda:0')

In [20]:
model.load_state_dict(torch.load('10omfg.pth', map_location=device))

<All keys matched successfully>

In [21]:
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.005)
loss = nn.CrossEntropyLoss()

In [22]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 2070 SUPER'

In [23]:
model.train()

VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): Sequential(
    (0): EncoderBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): MultiheadAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (dropout_1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout_2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): EncoderBlock(
      (norm1): LayerNorm((768,), eps=1e-06, ele

In [24]:
training_losses = []
val_losses = []
avg_training_losses = []
avg_val_losses = []

In [25]:
for e in tqdm(range(10)):
    for x, y in trainloader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        y_pred = model(x.float())
        l = loss(y_pred, y)
        l.backward()
        optimizer.step()

        npl = l.detach().cpu().numpy()
        training_losses.append(npl)

    for x, y in valloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x.float())
        l = loss(y_pred, y)

        npl = l.detach().cpu().numpy()
        val_losses.append(npl)

    avg_training_losses.append(np.mean(training_losses))
    avg_val_losses.append(np.mean(val_losses))

    print(f'Epoch: {e}, Training Loss: {np.mean(training_losses)}, Validation Loss: {np.mean(val_losses)}')

  2%|▉                                               | 1/50 [13:36<11:06:55, 816.65s/it]

Epoch: 0, Training Loss: 2.332289218902588, Validation Loss: 2.017587423324585


  4%|█▉                                              | 2/50 [27:18<10:55:45, 819.70s/it]

Epoch: 1, Training Loss: 2.1509275436401367, Validation Loss: 1.9676073789596558


  6%|██▉                                             | 3/50 [41:01<10:43:26, 821.41s/it]

Epoch: 2, Training Loss: 2.07155179977417, Validation Loss: 1.9438546895980835


  8%|███▊                                            | 4/50 [54:44<10:30:08, 821.93s/it]

Epoch: 3, Training Loss: 2.021153211593628, Validation Loss: 1.9118585586547852


 10%|████▌                                         | 5/50 [1:08:26<10:16:27, 821.94s/it]

Epoch: 4, Training Loss: 1.9796394109725952, Validation Loss: 1.8941690921783447


 12%|█████▌                                        | 6/50 [1:22:05<10:01:57, 820.85s/it]

Epoch: 5, Training Loss: 1.9402705430984497, Validation Loss: 1.8601536750793457


 14%|██████▌                                        | 7/50 [1:35:43<9:47:33, 819.85s/it]

Epoch: 6, Training Loss: 1.9198921918869019, Validation Loss: 1.8475416898727417


 16%|███████▌                                       | 8/50 [1:49:19<9:33:04, 818.68s/it]

Epoch: 7, Training Loss: 1.8996692895889282, Validation Loss: 1.8318746089935303


 18%|████████▍                                      | 9/50 [2:02:55<9:18:50, 817.82s/it]

Epoch: 8, Training Loss: 1.8873989582061768, Validation Loss: 1.8242069482803345


 20%|█████████▏                                    | 10/50 [2:16:31<9:04:50, 817.27s/it]

Epoch: 9, Training Loss: 1.8713998794555664, Validation Loss: 1.8182131052017212


 22%|██████████                                    | 11/50 [2:30:06<8:50:45, 816.56s/it]

Epoch: 10, Training Loss: 1.8570749759674072, Validation Loss: 1.8224866390228271


 24%|███████████                                   | 12/50 [2:43:42<8:37:02, 816.39s/it]

Epoch: 11, Training Loss: 1.8504129648208618, Validation Loss: 1.818852186203003


 26%|███████████▉                                  | 13/50 [2:58:25<8:35:57, 836.70s/it]

Epoch: 12, Training Loss: 1.8404426574707031, Validation Loss: 1.8048964738845825


 28%|████████████▉                                 | 14/50 [3:15:21<8:54:27, 890.75s/it]

Epoch: 13, Training Loss: 1.8269356489181519, Validation Loss: 1.8145118951797485


 30%|█████████████▊                                | 15/50 [3:32:34<9:04:42, 933.77s/it]

Epoch: 14, Training Loss: 1.8242048025131226, Validation Loss: 1.8075532913208008


 32%|██████████████▋                               | 16/50 [3:49:22<9:01:47, 956.10s/it]

Epoch: 15, Training Loss: 1.8342585563659668, Validation Loss: 1.8131433725357056


 34%|███████████████▋                              | 17/50 [4:06:18<8:55:44, 974.08s/it]

Epoch: 16, Training Loss: 1.8348431587219238, Validation Loss: 1.8130311965942383


 36%|████████████████▌                             | 18/50 [4:23:34<8:49:20, 992.52s/it]

Epoch: 17, Training Loss: 1.831372618675232, Validation Loss: 1.8063491582870483


 38%|█████████████████                            | 19/50 [4:40:51<8:39:42, 1005.87s/it]

Epoch: 18, Training Loss: 1.8269144296646118, Validation Loss: 1.799658179283142


 40%|██████████████████                           | 20/50 [4:57:49<8:24:51, 1009.71s/it]

Epoch: 19, Training Loss: 1.8186451196670532, Validation Loss: 1.7910112142562866


 42%|██████████████████▉                          | 21/50 [5:14:42<8:08:30, 1010.69s/it]

Epoch: 20, Training Loss: 1.8165481090545654, Validation Loss: 1.7942672967910767


 44%|████████████████████▏                         | 22/50 [5:28:16<7:24:06, 951.66s/it]

Epoch: 21, Training Loss: 1.817517638206482, Validation Loss: 1.7998852729797363


 46%|█████████████████████▏                        | 23/50 [5:41:51<6:49:45, 910.57s/it]

Epoch: 22, Training Loss: 1.8176735639572144, Validation Loss: 1.798032283782959


 48%|██████████████████████                        | 24/50 [5:55:24<6:21:56, 881.41s/it]

Epoch: 23, Training Loss: 1.8191008567810059, Validation Loss: 1.7998161315917969


 50%|███████████████████████                       | 25/50 [6:08:58<5:58:49, 861.19s/it]

Epoch: 24, Training Loss: 1.8179514408111572, Validation Loss: 1.7977267503738403


 52%|███████████████████████▉                      | 26/50 [6:22:30<5:38:32, 846.37s/it]

Epoch: 25, Training Loss: 1.8157929182052612, Validation Loss: 1.794926643371582


 54%|████████████████████████▊                     | 27/50 [6:36:03<5:20:35, 836.34s/it]

Epoch: 26, Training Loss: 1.8142690658569336, Validation Loss: 1.7932124137878418


 56%|█████████████████████████▊                    | 28/50 [6:49:38<5:04:18, 829.93s/it]

Epoch: 27, Training Loss: 1.8119425773620605, Validation Loss: 1.7905657291412354


 58%|██████████████████████████▋                   | 29/50 [7:03:11<4:48:41, 824.83s/it]

Epoch: 28, Training Loss: 1.8104405403137207, Validation Loss: 1.7909510135650635


 60%|███████████████████████████▌                  | 30/50 [7:16:44<4:33:48, 821.43s/it]

Epoch: 29, Training Loss: 1.8078068494796753, Validation Loss: 1.789329171180725


 62%|████████████████████████████▌                 | 31/50 [7:30:17<4:19:14, 818.66s/it]

Epoch: 30, Training Loss: 1.8047668933868408, Validation Loss: 1.7867907285690308


 64%|█████████████████████████████▍                | 32/50 [7:43:52<4:05:17, 817.66s/it]

Epoch: 31, Training Loss: 1.8012275695800781, Validation Loss: 1.7840077877044678


 66%|██████████████████████████████▎               | 33/50 [7:57:23<3:51:07, 815.72s/it]

Epoch: 32, Training Loss: 1.7978441715240479, Validation Loss: 1.779195785522461


 68%|███████████████████████████████▎              | 34/50 [8:10:56<3:37:18, 814.89s/it]

Epoch: 33, Training Loss: 1.7938005924224854, Validation Loss: 1.7773278951644897


 70%|████████████████████████████████▏             | 35/50 [8:24:31<3:23:43, 814.88s/it]

Epoch: 34, Training Loss: 1.7915996313095093, Validation Loss: 1.7736314535140991


 72%|█████████████████████████████████             | 36/50 [8:38:06<3:10:08, 814.86s/it]

Epoch: 35, Training Loss: 1.7876378297805786, Validation Loss: 1.76983642578125


 74%|██████████████████████████████████            | 37/50 [8:51:38<2:56:24, 814.18s/it]

Epoch: 36, Training Loss: 1.7834199666976929, Validation Loss: 1.7657229900360107


 76%|██████████████████████████████████▉           | 38/50 [9:05:14<2:42:55, 814.62s/it]

Epoch: 37, Training Loss: 1.7826164960861206, Validation Loss: 1.7649911642074585


 78%|███████████████████████████████████▉          | 39/50 [9:18:46<2:29:11, 813.80s/it]

Epoch: 38, Training Loss: 1.7812825441360474, Validation Loss: 1.7634795904159546


 80%|████████████████████████████████████▊         | 40/50 [9:32:17<2:15:29, 812.96s/it]

Epoch: 39, Training Loss: 1.778870940208435, Validation Loss: 1.762717604637146


 82%|█████████████████████████████████████▋        | 41/50 [9:45:51<2:02:00, 813.33s/it]

Epoch: 40, Training Loss: 1.776458501815796, Validation Loss: 1.760607361793518


 84%|██████████████████████████████████████▋       | 42/50 [9:59:25<1:48:29, 813.64s/it]

Epoch: 41, Training Loss: 1.7735824584960938, Validation Loss: 1.7586801052093506


 86%|██████████████████████████████████████▋      | 43/50 [10:12:57<1:34:50, 812.98s/it]

Epoch: 42, Training Loss: 1.770975112915039, Validation Loss: 1.7564071416854858


 88%|███████████████████████████████████████▌     | 44/50 [10:26:27<1:21:13, 812.23s/it]

Epoch: 43, Training Loss: 1.7687928676605225, Validation Loss: 1.7577537298202515


 90%|████████████████████████████████████████▌    | 45/50 [10:39:56<1:07:36, 811.26s/it]

Epoch: 44, Training Loss: 1.7672713994979858, Validation Loss: 1.7558780908584595


 92%|███████████████████████████████████████████▏   | 46/50 [10:53:26<54:03, 810.88s/it]

Epoch: 45, Training Loss: 1.7670342922210693, Validation Loss: 1.7543842792510986


 94%|████████████████████████████████████████████▏  | 47/50 [11:06:55<40:30, 810.15s/it]

Epoch: 46, Training Loss: 1.765127420425415, Validation Loss: 1.753571629524231


 96%|█████████████████████████████████████████████  | 48/50 [11:20:23<26:59, 809.57s/it]

Epoch: 47, Training Loss: 1.7629648447036743, Validation Loss: 1.7509987354278564


 98%|██████████████████████████████████████████████ | 49/50 [11:33:53<13:29, 809.73s/it]

Epoch: 48, Training Loss: 1.7614150047302246, Validation Loss: 1.74971604347229


100%|███████████████████████████████████████████████| 50/50 [11:47:23<00:00, 848.86s/it]

Epoch: 49, Training Loss: 1.7593122720718384, Validation Loss: 1.747793197631836





In [26]:
torch.save(model, '50epoch.pth')

In [27]:
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): Sequential(
    (0): EncoderBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): MultiheadAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (dropout_1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout_2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): EncoderBlock(
      (norm1): LayerNorm((768,), eps=1e-06, ele

In [46]:
params = {'input_channels':3, 'dim':768, 'hidden_dim':3072, 'patch_size':16, 'img_size':224, 'num_layers':12, 
          'dropout':0.1, 'attention_dropout':0.1, 'num_heads':12, 'fine_tune':None, 'num_classes':10, 'encoder_norm':True, 'fc_norm':False}

In [47]:
model2 = VisionTransformer(**params)

In [48]:
model2 = model2.to(device)

In [49]:
model2.load_state_dict(torch.load('10omfg.pth', map_location=device))

<All keys matched successfully>

In [57]:
import timm
model2 = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=10)
model2 = model2.to(device)

In [53]:
a, b = [], []

for x, y in valloader:
    x, y = x.to(device), y.to(device)
    x = x.float()
    y_pred = model2(x)
    y_pred, y = y_pred.detach().cpu().numpy(), y.detach().cpu().numpy()
    yes = np.argmax(y_pred, axis=1)
    a.append(yes)
    b.append(y)

In [54]:
count = 0
for i, j in zip(a, b):
    for k, l in zip(i, j):
        if k == l:
            count += 1
count/10000

0.0

In [None]:
# def train(model, optimizer, loss, epochs, trainloader, valloader, device):
#     avg_tloss, avg_vloss = [], []
#     for e in tqdm(range(epochs)):
#         t_loss, v_loss = [], []
#         for x, y in trainloader:
#             optimizer.zero_grad()
#             x, y = x.to(device), y.to(device)
#             y_pred = model(x.float())
#             l = loss(y_pred, y)
#             l.backward()
#             optimizer.step()
            
#             npl = l.detach().cpu().numpy()
#             t_loss.append(npl)
            
#         with torch.no_grad():
#             for x, y in valloader:
#                 x, y = x.to(device), y.to(device)
#                 y_pred = model(x.float())
#                 l = loss(y_pred, y)
                
#                 npl = l.detach().cpu().numpy()
#                 v_loss.append(npl)
    
#         avg_tloss.append(np.mean(t_loss))
#         avg_vloss.append(np.mean(v_loss))

#         print(f'Epoch: {e}, Training Loss: {np.mean(t_loss)}, Validation Loss: {np.mean(v_loss)}')
    
#     return avg_tloss, avg_vloss, model, optimizer

In [None]:
# tloss, vloss, model, optimizer = train(model, optimizer, loss, 100, trainloader, valloader, device)