In [1]:
import timm
import torch
import pandas as pd
import os
import multiprocessing as mp

from torchvision import transforms as T
from cassava_leaf_disease import CassavaLeafDiseaseDataset
from sklearn import model_selection as ms
from tqdm import tqdm

In [5]:
# Hyper-parameters shared by the cells below.
lr = 1e-3  # learning rate intended for the optimizer
bs = 8     # batch size passed to both DataLoaders

# Best validation accuracy seen so far; compared against each epoch's accuracy
# by val_one_epoch to decide when to write the "best" checkpoint.
best_model = 0.0

# Per-epoch (loss, accuracy) tuples appended by the train / validation helpers.
train_history = []
val_history = []

# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def train_one_epoch(net, criterion, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader`` and record the result.

    Args:
        net: model being trained; called as ``net(inputs)``.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        train_loader: sized iterable yielding ``(inputs, labels)`` batches.

    Side effects:
        Moves every batch to the module-level ``device``, prints the epoch
        summary, and appends ``(mean_batch_loss, accuracy)`` to the
        module-level ``train_history`` list. Returns ``None``.
    """
    # Explicitly select training mode so mode-dependent layers (e.g. dropout)
    # behave correctly even if the model was previously switched to eval mode.
    net.train()

    running_loss = 0.0
    total = 0
    correct = 0

    progress = tqdm(train_loader, total=len(train_loader))
    for inputs, labels in progress:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

        # Index of the highest score along dim 1 is the predicted class.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    num_batches = len(train_loader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = correct / total if total else 0.0

    print('epoch loss ', epoch_loss, ' epoch acc ', epoch_acc)
    train_history.append((epoch_loss, epoch_acc))

In [3]:
def val_one_epoch(net, val_loader, epoch):
    """Evaluate ``net`` on ``val_loader`` and checkpoint its weights.

    Args:
        net: model to evaluate; called as ``net(inputs)``.
        val_loader: sized iterable yielding ``(inputs, labels)`` batches.
        epoch: value substituted into the per-epoch checkpoint file name.

    Side effects:
        * Uses the module-level ``criterion`` and ``device``.
        * Appends ``(mean_batch_loss, accuracy)`` to the module-level
          ``val_history`` list and prints the same numbers.
        * Saves ``net.state_dict()`` to ``./models/<epoch>.pth`` and, when the
          accuracy beats the module-level ``best_model``, updates
          ``best_model`` and also saves to ``./models/best.pth``.

    Returns ``None``.
    """
    # `best_model` is rebound below, so it must be declared global; without
    # this the comparison raises UnboundLocalError.
    global best_model

    running_loss = 0.0
    total = 0
    correct = 0

    # Evaluate in eval mode, then restore whatever mode the model was in so
    # subsequent training is unaffected.
    was_training = net.training
    net.eval()
    try:
        # Iterate the same progress bar that receives the postfix updates.
        progress = tqdm(val_loader, total=len(val_loader))
        with torch.no_grad():
            for inputs, labels in progress:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)

                batch_loss = criterion(outputs, labels).item()
                running_loss += batch_loss
                progress.set_postfix(loss=batch_loss)

                # Index of the highest score along dim 1 is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    # Average over the *validation* batches; guard against an empty loader.
    num_batches = len(val_loader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = correct / total if total else 0.0

    # Store the mean loss so the entries are comparable with train_history.
    val_history.append((epoch_loss, epoch_acc))
    print('epoch loss ', epoch_loss, ' epoch acc ', epoch_acc)

    # torch.save does not create missing directories.
    os.makedirs('./models', exist_ok=True)

    # One checkpoint per epoch; '{}' is the placeholder str.format understands.
    torch.save(net.state_dict(), './models/{}.pth'.format(epoch))

    if best_model < epoch_acc:
        best_model = epoch_acc
        torch.save(net.state_dict(), './models/best.pth')

In [4]:
# Directory holding the dataset; the data-loading cell reads `train.csv` from
# it. Set the CASSAVA_DATA_ROOT environment variable to point at a different
# location without editing the notebook; the default keeps the original path.
data_root = os.environ.get('CASSAVA_DATA_ROOT', '/newDriver/nam/cassava-leaf-disease')

# Base directory for notebook outputs.
working_root = './'

In [6]:
# Preprocessing shared by both splits: convert to a tensor, resize to 384x384,
# then normalise each channel (the constants match the widely used ImageNet
# mean / std values).
_channel_mean = (0.485, 0.456, 0.406)
_channel_std = (0.229, 0.224, 0.225)
transform = T.Compose([
    T.ToTensor(),
    T.Resize((384, 384)),
    T.Normalize(_channel_mean, _channel_std),
])

# Load the label table and carve out a stratified 20% validation split with a
# fixed seed so the split is repeatable.
df = pd.read_csv(os.path.join(data_root, 'train.csv'))
train_df, val_df = ms.train_test_split(
    df,
    test_size=0.2,
    stratify=df.label.values,
    random_state=42,
)

train_dataset = CassavaLeafDiseaseDataset(data_root, df=train_df, transform=transform)
val_dataset = CassavaLeafDiseaseDataset(data_root, df=val_df, transform=transform)

# Only the training loader shuffles; both use one worker per reported CPU.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
    num_workers=mp.cpu_count(),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=bs,
    shuffle=False,
    num_workers=mp.cpu_count(),
)

In [7]:
# Names of every timm model matching 'vit*'; not referenced by the other
# visible cells.
vit_model_names = timm.list_models('vit*')

# Pretrained timm 'vit_base_patch16_384' with a 5-class head, moved to the
# configured device.
vit16 = timm.create_model(
    'vit_base_patch16_384',
    pretrained=True,
    num_classes=5,
)
vit16 = vit16.to(device)

In [8]:
# Cross-entropy loss over the class scores; also read as a global by
# val_one_epoch.
criterion = torch.nn.CrossEntropyLoss()

# Use the `lr` defined in the configuration cell instead of repeating the
# literal, so the learning rate is tuned in one place.
optimizer = torch.optim.Adam(vit16.parameters(), lr=lr)

In [None]:
# Number of full train + validation passes to run.
num_epochs = 10

for epoch in range(num_epochs):
    # Each epoch trains on the training split, then evaluates and checkpoints
    # on the validation split.
    train_one_epoch(vit16, criterion, optimizer, train_loader)
    val_one_epoch(vit16, val_loader, epoch)

 98%|█████████▊| 2090/2140 [16:13<00:23,  2.13it/s, loss=0.845]