In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

### Dataset preparation:

- Read train dataset
- Create validation dataset
- Create Pytorch custom dataset preparation class
- Create Pytorch dataloaders



In [2]:
train = pd.read_csv('/kaggle/input/digit-recognizer/train.csv')

In [3]:
train.head()

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
train.shape

(42000, 785)

In [5]:
train_df , val_df = train_test_split(train,test_size=0.2 , random_state=0,shuffle=True)

In [6]:
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

In [7]:
class CustomDataset(Dataset):
    def __init__(self , data , dataset_type):
        self.images = data.iloc[:,1:].values
        self.labels = data.iloc[: , 0].values
        self.dataset_type = dataset_type
        self.train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])])
        self.valid_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self,idx):
        image = self.images[idx].reshape(28,28).astype(np.uint8)
        label = self.labels[idx]
        if self.dataset_type=='train':
            image = self.train_transform(image)
        else:
            image = self.valid_transform(image)
        
        return image,label
            

In [8]:
train_ds = CustomDataset(train_df, 'train')
valid_ds = CustomDataset(val_df, 'valid')

In [9]:
len(train_ds) , len(valid_ds)

(33600, 8400)

In [10]:
np.unique(train_ds.labels)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [11]:
BATCH_SIZE = 512
train_dataloader = DataLoader(dataset=train_ds,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

val_dataloader = DataLoader(dataset=valid_ds,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

In [12]:
for image,label in train_dataloader:
    print(image.shape)
    print(label.shape)
    break

torch.Size([512, 1, 28, 28])
torch.Size([512])


### Vision Transformer

In [13]:
import torch
from torch import nn


#### Parameters

In [14]:
NUM_CLASSES = 10
PATCH_SIZE = 4
IMG_SIZE = 28
IN_CHANNELS = 1
NUM_HEADS = 8
DROPOUT = 0.001
HIDDEN_DIM = 768
ACTIVATION="gelu"
NUM_ENCODERS = 8
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # 16
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2 # 49

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

#### PatchEmbedding class

In [15]:
class PatchEmbedding(nn.Module):
    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),                  
            nn.Flatten(2))

        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x 
        x = self.dropout(x)
        return x
    
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)

In [16]:
image.shape

torch.Size([512, 1, 28, 28])

In [17]:
model.eval()
patched = model(image.to(device))

In [18]:
patched.shape

torch.Size([512, 50, 16])

#### Vision Transformer

In [19]:
class VissionTransformer(nn.Module):
    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, 
                 num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, 
                                                   activation=activation, batch_first=True, norm_first=True)
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.embeddings_block(x)    # Patch embedding blocks
        x = self.encoder_blocks(x)     # Encoder blocker with Multi-head attention
        x = self.mlp_head(x[:, 0, :])  # Apply classification on the CLS token only - MLP layer
        return x

In [20]:
model = VissionTransformer(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM,
                           NUM_ENCODERS, NUM_HEADS, HIDDEN_DIM,
                           DROPOUT, ACTIVATION, IN_CHANNELS).to(device)



In [21]:
model.eval()
mpl_output = model(image.to(device))

In [22]:
mpl_output.shape

torch.Size([512, 10])

In [23]:
mpl_output[0]

tensor([ 0.4043, -0.1069, -0.8049,  0.9896,  0.0604, -1.1779, -0.1905,  0.4683,
         0.5640, -0.6701], device='cuda:0', grad_fn=<SelectBackward0>)

### Loss function and Optimizer

In [24]:
# loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

### Model Training

In [25]:
EPOCHS = 100 ## run the model for 100 epochs

for epoch in tqdm(range(EPOCHS)): 
    train_loss, valid_loss = [], []
    ## training part 
    model.train()
    for images, labels in train_dataloader:
        
        #moving data into the device
        images , labels = images.to(device) , labels.to(device)
        
        optimizer.zero_grad()
        ## 1. forward propagation
        outputs = model(images)
        
        ## 2. loss calculation
        loss = criterion(outputs, labels)
        
        ## 3. backward propagation
        loss.backward()
        
        ## 4. weight optimization
        optimizer.step()
        
        train_loss.append(loss.item())
        
    ## evaluation part
    with torch.no_grad():
        model.eval()
        for images, labels in val_dataloader:
            images , labels = images.to(device) , labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            valid_loss.append(loss.item())
    print ("Epoch:", epoch, "Training Loss: ", np.mean(train_loss), "Valid Loss: ", np.mean(valid_loss))

  1%|          | 1/100 [00:20<33:47, 20.48s/it]

Epoch: 0 Training Loss:  2.3074351621396616 Valid Loss:  2.2789344647351433


  2%|▏         | 2/100 [00:40<33:14, 20.35s/it]

Epoch: 1 Training Loss:  2.0736881295839944 Valid Loss:  1.8663071534212898


  3%|▎         | 3/100 [01:00<32:36, 20.17s/it]

Epoch: 2 Training Loss:  1.530078933094487 Valid Loss:  1.2056716820772957


  4%|▍         | 4/100 [01:20<32:12, 20.13s/it]

Epoch: 3 Training Loss:  1.030517140121171 Valid Loss:  0.7714457126224742


  5%|▌         | 5/100 [01:40<31:48, 20.09s/it]

Epoch: 4 Training Loss:  0.71224056410067 Valid Loss:  0.5390141834230984


  6%|▌         | 6/100 [02:01<31:32, 20.13s/it]

Epoch: 5 Training Loss:  0.5409891447334578 Valid Loss:  0.39325466927360087


  7%|▋         | 7/100 [02:21<31:13, 20.15s/it]

Epoch: 6 Training Loss:  0.4253506976546663 Valid Loss:  0.3614175617694855


  8%|▊         | 8/100 [02:41<30:51, 20.12s/it]

Epoch: 7 Training Loss:  0.3887406849499905 Valid Loss:  0.3061814045204836


  9%|▉         | 9/100 [03:01<30:31, 20.12s/it]

Epoch: 8 Training Loss:  0.3381983244960958 Valid Loss:  0.2824194238466375


 10%|█         | 10/100 [03:21<30:08, 20.10s/it]

Epoch: 9 Training Loss:  0.3049949242761641 Valid Loss:  0.23561030976912556


 11%|█         | 11/100 [03:41<29:46, 20.08s/it]

Epoch: 10 Training Loss:  0.28319574599013186 Valid Loss:  0.20131764604764826


 12%|█▏        | 12/100 [04:01<29:23, 20.04s/it]

Epoch: 11 Training Loss:  0.24987637341925592 Valid Loss:  0.23283340913407943


 13%|█▎        | 13/100 [04:21<29:02, 20.03s/it]

Epoch: 12 Training Loss:  0.2387906556779688 Valid Loss:  0.21184723955743454


 14%|█▍        | 14/100 [04:41<28:43, 20.04s/it]

Epoch: 13 Training Loss:  0.22102367674762552 Valid Loss:  0.1783640849239686


 15%|█▌        | 15/100 [05:01<28:20, 20.01s/it]

Epoch: 14 Training Loss:  0.20318910682743246 Valid Loss:  0.154226756271194


 16%|█▌        | 16/100 [05:21<28:00, 20.01s/it]

Epoch: 15 Training Loss:  0.19320184218161035 Valid Loss:  0.1531410747591187


 17%|█▋        | 17/100 [05:41<27:42, 20.03s/it]

Epoch: 16 Training Loss:  0.18703130023045975 Valid Loss:  0.15204970188000622


 18%|█▊        | 18/100 [06:01<27:27, 20.09s/it]

Epoch: 17 Training Loss:  0.17726860037355713 Valid Loss:  0.13697503507137299


 19%|█▉        | 19/100 [06:21<27:07, 20.09s/it]

Epoch: 18 Training Loss:  0.17865307888749873 Valid Loss:  0.14563955498092315


 20%|██        | 20/100 [06:41<26:48, 20.11s/it]

Epoch: 19 Training Loss:  0.1657599126073447 Valid Loss:  0.14200054722673752


 21%|██        | 21/100 [07:02<26:27, 20.10s/it]

Epoch: 20 Training Loss:  0.1616513647816398 Valid Loss:  0.14421745477353826


 22%|██▏       | 22/100 [07:21<26:02, 20.03s/it]

Epoch: 21 Training Loss:  0.15216465813644003 Valid Loss:  0.1326837859609548


 23%|██▎       | 23/100 [07:41<25:41, 20.02s/it]

Epoch: 22 Training Loss:  0.154277023605325 Valid Loss:  0.13087759210782893


 24%|██▍       | 24/100 [08:02<25:24, 20.05s/it]

Epoch: 23 Training Loss:  0.1458489972759377 Valid Loss:  0.11473656138953041


 25%|██▌       | 25/100 [08:21<24:59, 19.99s/it]

Epoch: 24 Training Loss:  0.14276967005747737 Valid Loss:  0.1282327100634575


 26%|██▌       | 26/100 [08:41<24:36, 19.96s/it]

Epoch: 25 Training Loss:  0.13918274584593196 Valid Loss:  0.11745885102187886


 27%|██▋       | 27/100 [09:01<24:16, 19.95s/it]

Epoch: 26 Training Loss:  0.12814083006797414 Valid Loss:  0.11072445266387042


 28%|██▊       | 28/100 [09:21<23:54, 19.93s/it]

Epoch: 27 Training Loss:  0.12523015599810716 Valid Loss:  0.11547478419892929


 29%|██▉       | 29/100 [09:41<23:33, 19.91s/it]

Epoch: 28 Training Loss:  0.12422902328949986 Valid Loss:  0.11092601563124095


 30%|███       | 30/100 [10:01<23:15, 19.94s/it]

Epoch: 29 Training Loss:  0.11633546232725635 Valid Loss:  0.1141378634116229


 31%|███       | 31/100 [10:21<22:58, 19.97s/it]

Epoch: 30 Training Loss:  0.12198439579118382 Valid Loss:  0.10264060733949437


 32%|███▏      | 32/100 [10:41<22:45, 20.08s/it]

Epoch: 31 Training Loss:  0.11090980758043853 Valid Loss:  0.10017765751656364


 33%|███▎      | 33/100 [11:01<22:22, 20.04s/it]

Epoch: 32 Training Loss:  0.10835015965682088 Valid Loss:  0.09450072181575439


 34%|███▍      | 34/100 [11:21<21:59, 20.00s/it]

Epoch: 33 Training Loss:  0.10978094595625545 Valid Loss:  0.09777447756598978


 35%|███▌      | 35/100 [11:41<21:39, 20.00s/it]

Epoch: 34 Training Loss:  0.10706373699235194 Valid Loss:  0.11381940894267138


 36%|███▌      | 36/100 [12:01<21:21, 20.02s/it]

Epoch: 35 Training Loss:  0.10475393395983812 Valid Loss:  0.09637174054103739


 37%|███▋      | 37/100 [12:21<21:00, 20.01s/it]

Epoch: 36 Training Loss:  0.10478957635209415 Valid Loss:  0.09436714824508219


 38%|███▊      | 38/100 [12:41<20:38, 19.97s/it]

Epoch: 37 Training Loss:  0.1037956876398036 Valid Loss:  0.0935365003259743


 39%|███▉      | 39/100 [13:01<20:16, 19.94s/it]

Epoch: 38 Training Loss:  0.0987650566367489 Valid Loss:  0.0911190110970946


 40%|████      | 40/100 [13:21<19:56, 19.94s/it]

Epoch: 39 Training Loss:  0.1036397847488071 Valid Loss:  0.09840014796046649


 41%|████      | 41/100 [13:41<19:36, 19.95s/it]

Epoch: 40 Training Loss:  0.09251512705602428 Valid Loss:  0.08683148705784012


 42%|████▏     | 42/100 [14:01<19:14, 19.90s/it]

Epoch: 41 Training Loss:  0.0905681751442678 Valid Loss:  0.08110525700099327


 43%|████▎     | 43/100 [14:21<18:55, 19.92s/it]

Epoch: 42 Training Loss:  0.09001318624976909 Valid Loss:  0.09383908944094882


 44%|████▍     | 44/100 [14:41<18:38, 19.97s/it]

Epoch: 43 Training Loss:  0.08561638590287078 Valid Loss:  0.09360727580154643


 45%|████▌     | 45/100 [15:01<18:18, 19.97s/it]

Epoch: 44 Training Loss:  0.0893322318566568 Valid Loss:  0.09191804817494224


 46%|████▌     | 46/100 [15:21<18:00, 20.01s/it]

Epoch: 45 Training Loss:  0.08870207185320782 Valid Loss:  0.08942620127516634


 47%|████▋     | 47/100 [15:41<17:38, 19.97s/it]

Epoch: 46 Training Loss:  0.08590622872791508 Valid Loss:  0.08203226634684731


 48%|████▊     | 48/100 [16:01<17:16, 19.93s/it]

Epoch: 47 Training Loss:  0.08784733904582082 Valid Loss:  0.09289622241083313


 49%|████▉     | 49/100 [16:20<16:54, 19.89s/it]

Epoch: 48 Training Loss:  0.08240732561909792 Valid Loss:  0.0801737174830016


 50%|█████     | 50/100 [16:40<16:33, 19.88s/it]

Epoch: 49 Training Loss:  0.0835983408897212 Valid Loss:  0.07776887386160738


 51%|█████     | 51/100 [17:00<16:15, 19.90s/it]

Epoch: 50 Training Loss:  0.08213176967745478 Valid Loss:  0.07986840408514528


 52%|█████▏    | 52/100 [17:20<16:01, 20.02s/it]

Epoch: 51 Training Loss:  0.07651378513511384 Valid Loss:  0.08000511215890155


 53%|█████▎    | 53/100 [17:41<15:41, 20.04s/it]

Epoch: 52 Training Loss:  0.07811328593754407 Valid Loss:  0.07851962789016612


 54%|█████▍    | 54/100 [18:01<15:27, 20.17s/it]

Epoch: 53 Training Loss:  0.0765594674104994 Valid Loss:  0.09099153253962011


 55%|█████▌    | 55/100 [18:21<15:09, 20.22s/it]

Epoch: 54 Training Loss:  0.07881517533325788 Valid Loss:  0.07637968615573995


 56%|█████▌    | 56/100 [18:42<14:49, 20.21s/it]

Epoch: 55 Training Loss:  0.0735961738409418 Valid Loss:  0.08059857785701752


 57%|█████▋    | 57/100 [19:02<14:28, 20.20s/it]

Epoch: 56 Training Loss:  0.07390781571016167 Valid Loss:  0.07286213097326896


 58%|█████▊    | 58/100 [19:22<14:07, 20.17s/it]

Epoch: 57 Training Loss:  0.06852003655424624 Valid Loss:  0.0793792984503157


 59%|█████▉    | 59/100 [19:42<13:46, 20.16s/it]

Epoch: 58 Training Loss:  0.06980084723821192 Valid Loss:  0.07394098249428413


 60%|██████    | 60/100 [20:03<13:32, 20.32s/it]

Epoch: 59 Training Loss:  0.0729751346181288 Valid Loss:  0.07432378653217764


 61%|██████    | 61/100 [20:23<13:13, 20.34s/it]

Epoch: 60 Training Loss:  0.07289542714980515 Valid Loss:  0.08383134971646701


 62%|██████▏   | 62/100 [20:43<12:51, 20.30s/it]

Epoch: 61 Training Loss:  0.06871595848916155 Valid Loss:  0.07288329719620593


 63%|██████▎   | 63/100 [21:04<12:30, 20.29s/it]

Epoch: 62 Training Loss:  0.0672575207709363 Valid Loss:  0.0737360360867837


 64%|██████▍   | 64/100 [21:24<12:10, 20.29s/it]

Epoch: 63 Training Loss:  0.06588701680867058 Valid Loss:  0.06983178651289028


 65%|██████▌   | 65/100 [21:44<11:48, 20.25s/it]

Epoch: 64 Training Loss:  0.0623280086244146 Valid Loss:  0.08127072869854815


 66%|██████▌   | 66/100 [22:04<11:25, 20.16s/it]

Epoch: 65 Training Loss:  0.06381553315529317 Valid Loss:  0.07794346627505387


 67%|██████▋   | 67/100 [22:24<11:02, 20.06s/it]

Epoch: 66 Training Loss:  0.07221724335668665 Valid Loss:  0.07077800701646243


 68%|██████▊   | 68/100 [22:44<10:42, 20.07s/it]

Epoch: 67 Training Loss:  0.06394811663212198 Valid Loss:  0.07389223619418986


 69%|██████▉   | 69/100 [23:04<10:20, 20.02s/it]

Epoch: 68 Training Loss:  0.06327792852552551 Valid Loss:  0.07792774127686725


 70%|███████   | 70/100 [23:24<09:59, 19.97s/it]

Epoch: 69 Training Loss:  0.06716641120499733 Valid Loss:  0.07413562694016625


 71%|███████   | 71/100 [23:44<09:41, 20.04s/it]

Epoch: 70 Training Loss:  0.06038471966078787 Valid Loss:  0.07283327859990738


 72%|███████▏  | 72/100 [24:04<09:22, 20.11s/it]

Epoch: 71 Training Loss:  0.06266224113377658 Valid Loss:  0.07600619424791898


 73%|███████▎  | 73/100 [24:24<09:01, 20.04s/it]

Epoch: 72 Training Loss:  0.05761201849037951 Valid Loss:  0.07161419225089691


 74%|███████▍  | 74/100 [24:44<08:40, 20.04s/it]

Epoch: 73 Training Loss:  0.06191067648769328 Valid Loss:  0.07339521694709272


 75%|███████▌  | 75/100 [25:04<08:21, 20.05s/it]

Epoch: 74 Training Loss:  0.06021140848822666 Valid Loss:  0.07350004091858864


 76%|███████▌  | 76/100 [25:24<08:01, 20.08s/it]

Epoch: 75 Training Loss:  0.059480991920061184 Valid Loss:  0.07377837707891184


 77%|███████▋  | 77/100 [25:44<07:41, 20.05s/it]

Epoch: 76 Training Loss:  0.05794668008545131 Valid Loss:  0.060685089405845195


 78%|███████▊  | 78/100 [26:04<07:19, 19.98s/it]

Epoch: 77 Training Loss:  0.055728578375595986 Valid Loss:  0.07083893830285352


 79%|███████▉  | 79/100 [26:24<06:59, 19.99s/it]

Epoch: 78 Training Loss:  0.04988787890496579 Valid Loss:  0.07344494902474039


 80%|████████  | 80/100 [26:44<06:39, 19.97s/it]

Epoch: 79 Training Loss:  0.06227057048994483 Valid Loss:  0.07896064660128425


 81%|████████  | 81/100 [27:04<06:18, 19.92s/it]

Epoch: 80 Training Loss:  0.055846440154268887 Valid Loss:  0.07936962168006335


 82%|████████▏ | 82/100 [27:24<05:58, 19.91s/it]

Epoch: 81 Training Loss:  0.05037180173464797 Valid Loss:  0.07751607588108848


 83%|████████▎ | 83/100 [27:44<05:39, 20.00s/it]

Epoch: 82 Training Loss:  0.05112191792013067 Valid Loss:  0.07315877540146604


 84%|████████▍ | 84/100 [28:04<05:20, 20.04s/it]

Epoch: 83 Training Loss:  0.05123642304291328 Valid Loss:  0.06457340750185882


 85%|████████▌ | 85/100 [28:24<05:00, 20.06s/it]

Epoch: 84 Training Loss:  0.051882110394989 Valid Loss:  0.07770828595932792


 86%|████████▌ | 86/100 [28:44<04:40, 20.07s/it]

Epoch: 85 Training Loss:  0.05016442504005902 Valid Loss:  0.068197330350385


 87%|████████▋ | 87/100 [29:04<04:20, 20.04s/it]

Epoch: 86 Training Loss:  0.05062959698790854 Valid Loss:  0.07261530452353113


 88%|████████▊ | 88/100 [29:24<04:00, 20.03s/it]

Epoch: 87 Training Loss:  0.05868564404998765 Valid Loss:  0.0688866517123054


 89%|████████▉ | 89/100 [29:44<03:40, 20.05s/it]

Epoch: 88 Training Loss:  0.049425154755061325 Valid Loss:  0.07299960579942255


 90%|█████████ | 90/100 [30:04<03:20, 20.02s/it]

Epoch: 89 Training Loss:  0.04808143156608849 Valid Loss:  0.07116750496275284


 91%|█████████ | 91/100 [30:24<03:00, 20.01s/it]

Epoch: 90 Training Loss:  0.04888699045687011 Valid Loss:  0.07166871362749268


 92%|█████████▏| 92/100 [30:44<02:39, 19.98s/it]

Epoch: 91 Training Loss:  0.04721065480826479 Valid Loss:  0.06574484004693873


 93%|█████████▎| 93/100 [31:04<02:20, 20.04s/it]

Epoch: 92 Training Loss:  0.045722106669211025 Valid Loss:  0.07102178212474375


 94%|█████████▍| 94/100 [31:24<02:00, 20.03s/it]

Epoch: 93 Training Loss:  0.048662786936443866 Valid Loss:  0.07121110115857686


 95%|█████████▌| 95/100 [31:44<01:40, 20.04s/it]

Epoch: 94 Training Loss:  0.04906425912949172 Valid Loss:  0.08080859679509611


 96%|█████████▌| 96/100 [32:04<01:20, 20.07s/it]

Epoch: 95 Training Loss:  0.04610736535466982 Valid Loss:  0.06865326578126234


 97%|█████████▋| 97/100 [32:25<01:00, 20.13s/it]

Epoch: 96 Training Loss:  0.04482278064119093 Valid Loss:  0.06564648569945027


 98%|█████████▊| 98/100 [32:45<00:40, 20.07s/it]

Epoch: 97 Training Loss:  0.040725477801805195 Valid Loss:  0.06836360671064433


 99%|█████████▉| 99/100 [33:05<00:20, 20.09s/it]

Epoch: 98 Training Loss:  0.048382223837754944 Valid Loss:  0.07288752978338915


100%|██████████| 100/100 [33:25<00:00, 20.05s/it]

Epoch: 99 Training Loss:  0.044365706627793385 Valid Loss:  0.0688697697923464





In [26]:
torch.cuda.empty_cache()

### Inference using trained model

In [27]:
predictions = []
model.eval()

with torch.no_grad():
    for images, labels in tqdm(val_dataloader):
        images = images.to(device)
                
        outputs = model(images)
        predictions.extend([int(i) for i in torch.argmax(outputs, dim=1)])

100%|██████████| 17/17 [00:02<00:00,  6.75it/s]


In [28]:
from sklearn.metrics import accuracy_score
accuracy_score(val_df['label'].values,predictions)

0.09785714285714285