In [None]:
import os

import pandas as pd
import torch
import torch.optim as optim
from torch import nn

from dataset import snn_batch
from model.snn import SNN

In [None]:
# Cross-validation fold ids to train; each id selects data/fsl2_cv{id}_10000_2.csv.
cv_list = list(range(1, 6))

In [None]:
# Train one SNN model per cross-validation fold.
#
# For each fold id in `cv_list`, this reads that fold's CSV and builds a fresh
# SNN + Adam optimizer. It then trains on batches drawn by `snn_batch` for up to
# `epochs` iterations. Whenever the current batch loss is the lowest seen so far
# in this fold, the weights are saved to a per-fold checkpoint file.
# Training of a fold stops early when either:
#   * no new minimum has been seen for more than `early_stop` epochs, or
#   * the batch loss falls below `target_loss`.
# NOTE(review): the "best" checkpoint is chosen on a single freshly sampled
# training batch, not on a held-out validation set, so selection is noisy.

# Fold-independent setup, hoisted out of the per-fold loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
backbone_list = ['transformer', 'resnet', 'lstm']
backbone = backbone_list[0]

epochs = 10000
batch_size = 1024
early_stop = 2000  # max epochs allowed without a new minimum loss
target_loss = 0.001  # stop a fold once the batch loss drops below this

# torch.save does not create missing parent directories.
os.makedirs('checkpoints', exist_ok=True)

for cv in cv_list:
    fsl_df = pd.read_csv('data/fsl2_cv{}_10000_2.csv'.format(cv))
    train_index = '100_cv{}'.format(cv)
    checkpoint_path = 'checkpoints/fsl2_{}_{}_1024_2.pth'.format(backbone, train_index)

    model = SNN(1302, 701, input_dim=512, feature_dim=256, backbone=backbone, r_num_layers=1, t_num_layers=1,
                l_num_layers=1)
    model = model.to(device)
    criteria = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.5, 0.999))

    # +inf (not a finite sentinel such as 99) guarantees the first finite loss
    # is recorded and checkpointed; BCELoss can legitimately be as large as 100.
    min_loss = float('inf')
    best_epoch = 0

    model.train()
    for epoch in range(epochs):
        sample1, sample2, labels = snn_batch(fsl_df, batch_size=batch_size, pos_neg_ratio=0.5)
        sample1, sample2, labels = sample1.to(device), sample2.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(sample1.float(), sample2.float())
        loss = criteria(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a host/device sync.
        loss_value = loss.item()

        if loss_value < min_loss:
            min_loss = loss_value
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)

        if epoch - best_epoch > early_stop:
            break

        if epoch % 5 == 0:
            print('cv: {}, epoch: {}/{}, loss: {:.8f}'.format(cv, epoch, epochs, loss_value))
            print('cv: {}, Best epoch: {}, Min loss: {:.8f}'.format(cv, best_epoch, min_loss))

        if loss_value < target_loss:
            print('cv: {}, epoch: {}/{}, loss: {:.8f}'.format(cv, epoch, epochs, loss_value))
            print('cv: {}, Best epoch: {}, Min loss: {:.8f}'.format(cv, best_epoch, min_loss))
            break

    print("Finish Training.")