In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import numpy as np
from tqdm import tqdm

import datasets_classification as dataset
import resnet

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
images_dir = "../../data/vinbig_data/"
# get datasets
train_set = dataset.get_train_data(images_dir)
val_set = dataset.get_val_data(images_dir)
test_set = dataset.get_test_data(images_dir)

# setup dataloaders
batch_size = 16
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_set, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(test_set, batch_size = 1, shuffle = True)

In [13]:
# load model and setup training parameters
model = resnet.ResNet_18(image_channels=1, num_classes=15)
model = model.to(device)

epochs = 100
val_check = 10
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)

In [14]:
# train and eval the model
def save_ckpt(model, path):
    torch.save(model.state_dict(), path) 
    return

def get_metric(pred, target, thresh=0.5):
    pred = np.array(pred > thresh, dtype=float)
    return f1_score(y_true = target, y_pred = pred, average='samples')

def test(model, test_dataloader):
    """
    model = resnet.ResNet_18(image_channels=1, labels=10)
    weights = torch.load(model_path)
    model.load_state_dict(weights)
    """
    model = model.to(device)
    macro_f1 = []
    model.eval()
    with torch.no_grad():
        for data, targets in test_dataloader:
            data = data.to(device)
            targets = targets.type(torch.int8).to(device)

            pred = model(data)
            macro_f1.append(get_metric(pred.cpu(), targets.cpu()))
    
    return np.mean(macro_f1)

def validate(model, val_dataloader):
    val_losses = []
    
    model.eval()
    for (data, targets) in val_dataloader:
        data = data.to(device)
        targets = targets.to(device)

        pred = model(data)
        loss = criterion(pred.type(torch.float), targets.type(torch.float))
        val_losses.append(loss.item())
    
    return np.mean(val_losses)

def train(model, train_dataloader, val_dataloader, save_path):
    best_loss = np.inf
    for epoch in tqdm(range(epochs)):
        batch_losses = []
        for (imgs, targets) in train_dataloader:
            imgs, targets = imgs.to(device), targets.to(device)

            optimizer.zero_grad()

            pred = model(imgs)
            loss = criterion(pred.type(torch.float), targets.type(torch.float))

            batch_loss_value = loss.item()
            loss.backward()
            optimizer.step()

            batch_losses.append(batch_loss_value)

        loss = np.mean(batch_losses)
        print(f"Epoch: {epoch}")
        print(f"[TRAIN] Loss: {loss:3f}")
        
        if epoch % val_check == 0:
            val_loss = validate(model, val_dataloader)
        
            print(f"[VALIDATION] Loss: {val_loss:3f}\n")

            model.train()
            if val_loss < best_loss:
                best_loss = val_loss
                save_ckpt(model, save_path)

In [None]:
train(model, train_dataloader, val_dataloader, "model_classification_resnet_no_aug")

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 0
[TRAIN] Loss: 0.347970


  1%|          | 1/100 [00:32<52:48, 32.01s/it]

[VALIDATION] Loss: 4.551671



  2%|▏         | 2/100 [00:56<44:53, 27.48s/it]

Epoch: 1
[TRAIN] Loss: 0.284181


  3%|▎         | 3/100 [01:19<40:58, 25.34s/it]

Epoch: 2
[TRAIN] Loss: 0.269884


  4%|▍         | 4/100 [01:41<38:58, 24.36s/it]

Epoch: 3
[TRAIN] Loss: 0.266788


  5%|▌         | 5/100 [02:04<37:25, 23.64s/it]

Epoch: 4
[TRAIN] Loss: 0.276969


  6%|▌         | 6/100 [02:26<36:17, 23.17s/it]

Epoch: 5
[TRAIN] Loss: 0.272311


  7%|▋         | 7/100 [02:47<34:57, 22.55s/it]

Epoch: 6
[TRAIN] Loss: 0.262607


  8%|▊         | 8/100 [03:10<34:34, 22.55s/it]

Epoch: 7
[TRAIN] Loss: 0.260606


  9%|▉         | 9/100 [03:32<34:07, 22.50s/it]

Epoch: 8
[TRAIN] Loss: 0.265112


 10%|█         | 10/100 [03:55<33:38, 22.42s/it]

Epoch: 9
[TRAIN] Loss: 0.260651
Epoch: 10
[TRAIN] Loss: 0.253487


 11%|█         | 11/100 [04:25<37:05, 25.01s/it]

[VALIDATION] Loss: 0.304771

