In [None]:
from linear import Linear, BasicLinear
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm
import sklearn.metrics as metrics
from sklearn.ensemble import RandomForestClassifier
from configparser import Interpolation
import torch
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.data import DataLoader

# Fix the random number generators used in this notebook so that a
# Restart-and-Run-All gives the same shuffling and weight initialisation.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
# cuDNN autotuning (benchmark=True) may pick a different convolution algorithm
# from run to run, which works against the deterministic flag above, so it is
# switched off here.
torch.backends.cudnn.benchmark = False

In [None]:
class Dataset_Classification(Dataset):
    """Map-style dataset backed by an object with 'image' and 'label' entries.

    Both entries are indexed by integer sample position; each sample is
    returned as a pair of tensors built with ``torch.Tensor``.
    """

    def __init__(self, csv):
        """Store a reference to the backing table.

        Args:
            csv: Object that supports ``csv['image']`` and ``csv['label']``,
                each of which is indexable by sample position.
        """
        self.csv = csv

    def __len__(self):
        """Return the number of samples, i.e. the length of the 'image' entry."""
        images = self.csv['image']
        return len(images)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` tensor pair for sample ``idx``.

        The label is wrapped in a one-element list before conversion, so the
        target tensor always has shape ``(1,)``.
        """
        features = torch.Tensor(self.csv['image'][idx])
        target = torch.Tensor([self.csv['label'][idx]])
        return features, target

In [None]:
# Each .npy file holds a single pickled Python object, unwrapped with .item().
# NOTE(review): allow_pickle=True executes arbitrary code embedded in the file,
# so only load archives from a trusted source.
train, test = (
    np.load(split_path, allow_pickle=True).item()
    for split_path in ("./train.npy", "./test.npy")
)

In [None]:
# Wrap the two loaded splits so they can be batched by a DataLoader.
train_dataset = Dataset_Classification(csv=train)
test_dataset = Dataset_Classification(csv=test)

# Training batches are reshuffled every epoch; evaluation keeps file order.
trainloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
testloader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Legacy tensor type matching `device`; other cells cast batches with it.
type = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# Single-logit classifier on 2048-dimensional inputs.
model = Linear(
    in_channels=2048,
    embedding_channels=1024,
    hidden_channels=None,
    classes=1,
    depth=4,
    dim_reduction=True,
)
# Alternative baseline:
# model = BasicLinear(in_channels=2048, hidden_channels=2048, classes=1)

# NOTE(review): momentum=1e-6 is effectively zero - confirm it is intentional.
optimizer = optim.SGD(params=model.parameters(), lr=1.25e-3, momentum=1e-6)
# Multiplies the base learning rate by 0.975**epoch on every scheduler step.
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda epoch: 0.975 ** epoch,
)
criterion = nn.BCEWithLogitsLoss()

model = model.to(device)
criterion = criterion.to(device)
epochs = 200

In [None]:
# Best scores seen so far; both start at zero.
#   best_ar - area under the ROC curve
#   best_ap - area under the precision-recall curve
best_ar, best_ap = 0, 0

In [None]:
# Train for `epochs` epochs. After every epoch the model is scored on the test
# loader, and a checkpoint is written whenever the ROC AUC or the
# precision-recall AUC improves on the best value seen so far.
#
# The log is opened in a `with` block so the handle is closed even if training
# is interrupted. The name `file` is kept as in the original cell.
with open('train_log.txt', 'a') as file:
    for epoch in range(epochs):
        # ---- training pass ------------------------------------------------
        # Assumes `model` is an nn.Module (it already provides parameters(),
        # to() and state_dict()); train() re-enables layers such as dropout
        # or batch norm after the eval() call of the previous epoch.
        model.train()
        for idx, (inputs, labels) in enumerate(trainloader):
            # Move the batch to the model's device as float32. This replaces
            # the deprecated Variable wrapper and the legacy tensor-type cast.
            inputs = inputs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                tqdm.write(f'Epoch : {epoch} Iter : {idx}/{len(trainloader)} '
                           f'Loss : {loss.item():.4f} ', file=file)
        scheduler.step()

        # ---- evaluation pass ----------------------------------------------
        # eval() puts mode-dependent layers into inference behaviour so the
        # reported metrics do not depend on training-time randomness.
        model.eval()
        y_test = []
        y_pred = []
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs.to(device=device, dtype=torch.float32)
                output = model(inputs)
                # Flatten to 1-D so sklearn receives plain vectors instead of
                # (N, 1) column arrays. Labels never need to leave the CPU.
                y_test.extend(labels.reshape(-1).cpu().numpy())
                y_pred.extend(output.reshape(-1).cpu().numpy())
        y_test = np.array(y_test)
        y_pred = np.array(y_pred)

        # Raw logits are passed as scores; both curves depend only on ranking.
        precision, recall, _ = metrics.precision_recall_curve(y_test, y_pred, pos_label=1)
        ap = metrics.auc(recall, precision)
        fpr_roc, tpr_roc, _ = metrics.roc_curve(y_test, y_pred, pos_label=1)
        ar = metrics.auc(fpr_roc, tpr_roc)

        if best_ar < ar:
            print(f'epoch: {epoch} best_ar:{ar}')
            best_ar = ar
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_ar': best_ar
            }, "./best_ar_linear.pth.tar")
        if best_ap < ap:
            best_ap = ap
            print(f'epoch: {epoch} best_ap:{ap}')
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_ap': best_ap
            }, "./best_ap_linear.pth.tar")