In [None]:
import random
import os
import torch
import wandb

import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b0
from tqdm.notebook import tqdm

from data_utils import CustomDataset, get_transforms
from training import train

In [None]:
def seed_everything(seed: int) -> None:
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every CUDA
    device), and configures cuDNN so repeated runs pick the same kernels.

    Args:
        seed: value applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the hash seed of the
    # already-running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one (training may run on a
    # non-default device such as cuda:5).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and choose kernels per run, which can pick
    # different (non-deterministic) algorithms and defeats `deterministic=True`.
    torch.backends.cudnn.benchmark = False


# NOTE: GPU index 5 is hard-coded for the original machine — adjust as needed.
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')


config = dict(
    seed=0xbebebe,
    lr=0.001,
    n_epochs=24,
    batch_size=64,
    device=device,
    log_iters=200,
    augmentations=get_transforms('train'),
)

# Using the run as a context manager finishes it even if a step below raises
# (marking it failed), so an interrupted/crashed cell does not leave a dangling run.
with wandb.init(project='DL01-XRay', config=config) as run:
    seed_everything(wandb.config.seed)

    cfg = wandb.config
    model = efficientnet_b0(num_classes=5)
    model.to(device)
    train_dataset = CustomDataset('data/dev_train.csv', get_transforms('train'))
    val_dataset = CustomDataset('data/dev_val.csv', get_transforms('val'))
    dataloaders = dict(
        train=DataLoader(train_dataset, cfg.batch_size, shuffle=True, pin_memory=True, num_workers=8),
        val=DataLoader(val_dataset, cfg.batch_size, shuffle=False, pin_memory=True, num_workers=8),
    )

    # BCEWithLogitsLoss treats the 5 outputs as independent logits (presumably a
    # multi-label task).
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), cfg.lr, weight_decay=1e-4)
    # len(DataLoader) is the exact number of batches per epoch (ceil(N / batch_size)),
    # so the one-cycle schedule ends exactly at the last batch instead of overshooting
    # by one step per epoch when N is divisible by the batch size.
    # Assumes train() calls scheduler.step() once per training batch — TODO confirm.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.lr,
        pct_start=0.2,
        epochs=cfg.n_epochs,
        steps_per_epoch=len(dataloaders['train']),
    )

    train(
        model=model,
        dataloaders=dataloaders,
        optimizer=optimizer,
        criterion=criterion,
        n_epochs=cfg.n_epochs,
        scheduler=scheduler,
        device=device,
    )

In [None]:
def inference(model, dataloader, device='cpu'):
    """Run ``model`` over every batch of ``dataloader`` and return all outputs.

    The model is switched to eval mode (and left in it). Each batch must unpack as
    ``(inputs, targets)``; targets are ignored. The model itself is not moved, so it
    is expected to already live on ``device``.

    Args:
        model: a ``torch.nn.Module``.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to before the forward pass.

    Returns:
        A CPU tensor of model outputs for all samples, concatenated along dim 0 in
        loader order.
    """
    model.eval()
    batch_outputs = []
    for inputs, _targets in tqdm(dataloader):
        inputs_on_device = inputs.to(device)
        with torch.inference_mode():
            logits = model(inputs_on_device)
            batch_outputs.append(logits.detach().cpu())
    return torch.cat(batch_outputs)

def averaging(base_model, paths, num_classes=5):
    """Uniformly average the weights of several checkpoints of one architecture.

    Args:
        base_model: model constructor (e.g. ``efficientnet_b0``) accepting a
            ``num_classes`` keyword argument.
        paths: iterable of checkpoint files, each a ``state_dict`` saved with
            ``torch.save``.
        num_classes: passed to ``base_model``; must match the saved checkpoints.

    Returns:
        A state dict (loadable with ``load_state_dict``) in which every
        floating-point tensor is the mean of the corresponding tensors across all
        checkpoints, and every non-floating-point tensor (e.g. BatchNorm
        ``num_batches_tracked`` counters) is copied from the last checkpoint.

    Raises:
        ValueError: if ``paths`` is empty.
    """
    paths = list(paths)
    if not paths:
        # Averaging nothing would silently produce an all-zero model.
        raise ValueError('averaging() needs at least one checkpoint path')

    # A fresh model's state dict is the template, so the result keeps the key order
    # and the metadata state_dict() attaches; float entries are zeroed so they can
    # accumulate the mean.
    averaged = base_model(num_classes=num_classes).state_dict()
    for tensor in averaged.values():
        if tensor.is_floating_point():
            tensor.zero_()

    for path in tqdm(paths):
        model = base_model(num_classes=num_classes)
        # map_location='cpu' so checkpoints saved from any GPU load on any machine;
        # strict load_state_dict validates key names and shapes.
        # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
        model.load_state_dict(torch.load(path, map_location='cpu'))
        for key, tensor in model.state_dict().items():
            if tensor.is_floating_point():
                averaged[key] += tensor / len(paths)
            else:
                # Integer buffers (step counters) cannot be meaningfully averaged.
                averaged[key] = tensor.clone()

    return averaged

In [None]:
# Uniform weight averaging over the checkpoints saved during the last epochs of the
# run "worthy-yogurt-79" (in each file name the number after the run name is
# presumably the epoch and the last number the validation score).
# Append more checkpoint file names to the tuple to include them in the average.
paths = [
    f'checkpoints/{file_name}'
    for file_name in (
        'worthy-yogurt-79-15-0.8039_model.pt',
        'worthy-yogurt-79-16-0.8044_model.pt',
        'worthy-yogurt-79-17-0.8060_model.pt',
        'worthy-yogurt-79-18-0.8066_model.pt',
        'worthy-yogurt-79-19-0.8071_model.pt',
        'worthy-yogurt-79-20-0.8076_model.pt',
    )
]

cool_model = averaging(efficientnet_b0, paths)

In [None]:
# Test-time augmentation (TTA): predict on the test set tta + 1 times using the random
# *training* transforms and sum the raw model outputs (logits); the sum is divided by
# (tta + 1) when the submission is written.
tta = 40  # number of extra augmented passes after the first one

model = efficientnet_b0(num_classes=5)
model.load_state_dict(cool_model)  # weight-averaged state dict from the previous cell
model.to(device)

test_dataset = CustomDataset('data/sample_submission.csv', get_transforms('train'))
dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, num_workers=8)

preds = inference(model, dataloader, device)
for t in range(tta):
    preds += inference(model, dataloader, device)

In [None]:
# Build the submission. The template's first column is kept as-is; the remaining
# columns are replaced by the mean of the tta + 1 summed TTA predictions.
test = pd.read_csv('data/sample_submission.csv')
target_columns = test.columns[1:]
# Dividing by the number of passes is a monotonic rescaling, so it could be skipped
# if the metric only depends on the ranking of the scores (e.g. ROC-AUC).
mean_preds = (preds / (tta + 1)).cpu().numpy()
assert mean_preds.shape == (len(test), len(target_columns)), (
    f'predictions {mean_preds.shape} do not match the submission template '
    f'{(len(test), len(target_columns))}'
)
# Assign whole columns (not .iloc) so float predictions replace the template columns
# even if pandas read them as integers: writing floats into int columns through
# .iloc is deprecated in recent pandas.
test[target_columns] = mean_preds
os.makedirs('data/submissions', exist_ok=True)
test.to_csv('data/submissions/your_cool_submission.csv', index=False)