In [4]:
import torch 
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from torch import nn, optim
from sklearn.metrics import cohen_kappa_score
from dataset import RetinaDataset
import config

In [2]:
from utils import check_accuracy, make_prediction

In [3]:
def train_one_epoch(loader, model, optimizer, loss_fn, scaler, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        loader: Iterable yielding ``(inputs, targets, _)`` triples; the third
            element is ignored.
        model: Module being trained. It is switched to training mode here.
        optimizer: Optimizer that owns ``model``'s parameters.
        loss_fn: Callable ``loss_fn(scores, targets)`` returning a scalar loss.
        scaler: Gradient scaler (e.g. ``torch.cuda.amp.GradScaler``) used to
            scale the loss for mixed-precision training, or ``None`` to call
            ``loss.backward()`` / ``optimizer.step()`` directly.
        device: Device (string or ``torch.device``) the batches are moved to.

    Returns:
        float: Mean of the per-batch losses, or ``nan`` when ``loader`` yields
        no batches.
    """
    # Autocast is only enabled for CUDA targets; elsewhere it is a no-op.
    use_amp = torch.device(device).type == "cuda"

    # An evaluation helper run between epochs may have left the model in
    # eval mode, so training mode is re-asserted on every call.
    model.train()

    losses = []
    loop = tqdm(loader)
    for x, y, _ in loop:
        # Tensor.to() is NOT in-place: the returned tensor must be kept.
        x = x.to(device)
        y = y.to(device)

        with torch.autocast(device_type="cuda", enabled=use_amp):
            scores = model(x)
            loss = loss_fn(scores, y)

        optimizer.zero_grad()
        if scaler is not None:
            # Scale the loss so reduced-precision gradients do not underflow;
            # scaler.step() unscales them before the optimizer update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        loop.set_postfix(loss=losses[-1])

    return sum(losses) / len(losses) if losses else float("nan")

In [None]:
def main():
    """Train an EfficientNet-B3 classifier and run prediction on the test set.

    Builds the train/validation/test ``RetinaDataset`` objects and their
    loaders, replaces the pretrained classification head with a 5-output
    linear layer, then trains for ``config.EPOCHS`` epochs. After every epoch
    the quadratic-weighted Cohen's kappa on the validation loader is printed.
    Finally ``make_prediction`` is called with the model and the test loader.
    """
    # NOTE(review): the folder/CSV strings below look like placeholders —
    # replace them with the real dataset locations before running.
    train_dataset = RetinaDataset(
        img_folder="adfasdfasdf",
        path_to_csv="asdfasfe",
        transform=config.train_transform,
    )
    val_dataset = RetinaDataset(
        img_folder="asdfadfasdf",
        path_to_csv="asdfasdf",
        train=False,
        transform=config.val_transform,
    )
    test_dataset = RetinaDataset(
        img_folder="asdfasdfas",
        path_to_csv="asdfasdfasd",
        train=False,
        transform=config.val_transform,
    )

    # Keyword arguments throughout: DataLoader's 4th and 5th positional slots
    # are `sampler` and `batch_sampler`, not `num_workers` / `pin_memory`.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    loss_fn = nn.CrossEntropyLoss()
    model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
    # torchvision's EfficientNet keeps its head in `classifier` (Dropout, Linear);
    # assigning `_fc` would only attach an unused layer and leave 1000 outputs.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, 5)
    model = model.to(config.DEVICE)
    optimizer = optim.Adam(
        params=model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )
    # Loss scaling only applies to CUDA training; disable it elsewhere.
    scaler = torch.cuda.amp.GradScaler(
        enabled=torch.device(config.DEVICE).type == "cuda"
    )

    for epoch in range(config.EPOCHS):
        train_one_epoch(train_loader, model, optimizer, loss_fn, scaler, config.DEVICE)
        preds, labels = check_accuracy(val_loader, model, config.DEVICE)

        print(f"validation: {cohen_kappa_score(labels, preds, weights='quadratic')}")

    make_prediction(model, test_loader)