In [1]:
import torch
import numpy as np

from torch import nn
import torch.nn.functional as F

from model.vit import  ViT

In [2]:
from torchvision import transforms
from utils.data_loader import custom_dataloader, filename_to_tensor

tf = transforms.Compose([
    transforms.Resize([480,480]),
    transforms.ToTensor()
])

train_loader, val_loader, class_names = custom_dataloader("mushrooms_test", 8)
num_classes = len(class_names)
class_names

9
9
10
11
12
9
15
18
12


['Suillus',
 'Agaricus',
 'Entoloma',
 'Amanita',
 'Russula',
 'Lactarius',
 'Cortinarius',
 'Boletus',
 'Hygrocybe']

In [3]:
sample = train_loader['img_path'][0]
img_batch = filename_to_tensor(sample, tf)
img_batch.shape

torch.Size([8, 3, 480, 480])

In [4]:
model = ViT(
    image_size = 480,
    patch_size = 32,
    num_classes = num_classes,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

In [5]:
len(train_loader['img_path'])

11

In [9]:
preds = model(img_batch)
preds

tensor([[0.1277, 0.1120, 0.1453, 0.1286, 0.1106, 0.1043, 0.1052, 0.1517, 0.1491],
        [0.1523, 0.0998, 0.0953, 0.1315, 0.1058, 0.1090, 0.1349, 0.0987, 0.1147],
        [0.0928, 0.1642, 0.1148, 0.1757, 0.0934, 0.1545, 0.1506, 0.1242, 0.1485],
        [0.1347, 0.0977, 0.1619, 0.1108, 0.1770, 0.0953, 0.1162, 0.1463, 0.0793],
        [0.1193, 0.0894, 0.1265, 0.1096, 0.1474, 0.1054, 0.1102, 0.1379, 0.1259],
        [0.0934, 0.1469, 0.1126, 0.0882, 0.1372, 0.1233, 0.1407, 0.1187, 0.1823],
        [0.1298, 0.1648, 0.1149, 0.1182, 0.1148, 0.1630, 0.1211, 0.0904, 0.1089],
        [0.1501, 0.1252, 0.1287, 0.1374, 0.1139, 0.1451, 0.1211, 0.1321, 0.0913]],
       grad_fn=<SoftmaxBackward0>)

In [23]:
from torch import optim

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [24]:
loss_fn = nn.CrossEntropyLoss()  

In [25]:
actual = F.one_hot(train_loader['label'][0], 4).type(torch.float32)

In [26]:
from utils.train_loop import training_loop

n_epochs = 10
transform = tf
saved_path = 'weight'

training_loop(n_epochs, optimizer, model, loss_fn, train_loader, val_loader, transform, saved_path, eval_interval=3)

KeyboardInterrupt: 

: 