In [None]:
import os
import cv2
import json
import torch
import numpy as np
import albumentations as A
import torch.nn.functional as F
import torch.optim as optim

from utils import set_seed, custom_combine_collate_fn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

from dataset import *
from loss import *

In [None]:
# Run-wide setup: choose the epoch budget, hand unused cached GPU memory back
# to the driver, and seed the RNGs through the project's set_seed helper.
epochs = 50  # also passed as T_max to the cosine LR schedule below
torch.cuda.empty_cache()
set_seed(31)

In [None]:
# Load the previously saved model object (the whole module, not a state_dict:
# it is called directly and asked for .parameters() below).
# NOTE(review): torch.load unpickles the file, so only point it at checkpoints
# you trust; newer PyTorch releases may also require weights_only=False for a
# full-module pickle like this one -- confirm against the installed version.
model = torch.load("/opt/ml/level3_cv_finalproject-cv-01/model/save/Tresnet_m_ml_decoder_recipy_latest.pth")

# Asymmetric loss hyper-parameters, gathered in one place so they are easy to
# tweak; they are forwarded unchanged to the project's AsymmetricLoss.
asl_kwargs = dict(
    gamma_neg=4,
    gamma_pos=0,
    clip=0.05,
    eps=1e-8,
    disable_torch_grad_focal_loss=True,
)
criterion = AsymmetricLoss(**asl_kwargs)

# AdamW over every model parameter, with a cosine schedule that spans the
# whole run (T_max == epochs) and bottoms out at eta_min.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epochs, eta_min=1e-5)



In [None]:

# Preprocessing applied to every sample: resize to 224x224, then normalize
# with albumentations' default Normalize settings.
tf = A.Compose([A.Resize(224, 224), A.Normalize()])

# Project dataset built from an image directory, a JSON annotation directory
# and the transform above.
# NOTE(review): this feeds the *training* loader but reads from "test_data"
# paths -- confirm that is the intended split.
dataset = CustomCombineWeekClassDataset(
    "/opt/ml/level3_cv_finalproject-cv-01/model/data/test_data",
    "/opt/ml/level3_cv_finalproject-cv-01/model/data/test_data_json",
    tf,
)

# Shuffled batches of 256 assembled by the project's collate function; the
# final partial batch is kept (drop_last=False).
loader_kwargs = dict(
    batch_size=256,
    shuffle=True,
    num_workers=8,
    collate_fn=custom_combine_collate_fn,
    drop_last=False,
)
train_loader = DataLoader(dataset, **loader_kwargs)

In [None]:
# Mixed-precision fine-tuning loop.
#
# Each batch runs forward/loss under autocast, back-propagates through the
# GradScaler, and takes one optimizer step. The learning-rate scheduler is
# advanced once per epoch, matching its T_max=epochs configuration; without
# that call the cosine schedule would never move and the LR would stay at its
# initial value for the whole run.
scaler = torch.cuda.amp.GradScaler(enabled=True)
model.cuda()
for epoch in range(epochs):
    model.train()
    for images, labels in train_loader:
        # as_tensor accepts whatever the collate function produced (arrays or
        # tensors) without the extra copy / copy-construct warning that
        # torch.tensor() triggers when it is handed an existing tensor.
        images = torch.as_tensor(images).cuda()
        labels = torch.as_tensor(labels).cuda()
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images)
            loss = criterion(outputs, labels)
        # Scale the loss before backward so fp16 gradients do not underflow;
        # scaler.step() skips the update if it finds inf/NaN gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        print(f"{epoch}Epochs\tLoss: {round(loss.item(),4)}, ")

    # One scheduler tick per epoch, after that epoch's optimizer steps.
    scheduler.step()

# Persist the whole model object.
# NOTE(review): this is the same file the model was loaded from, so the
# original checkpoint is overwritten -- confirm that is intended.
output_path = os.path.join(
    "/opt/ml/level3_cv_finalproject-cv-01/model/save",
    "Tresnet_m_ml_decoder_recipy_latest.pth"
)
torch.save(model, output_path)