In [1]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

from torch.cuda.amp import autocast
from ptflops import get_model_complexity_info

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2


from sklearn.metrics import f1_score,roc_auc_score,precision_score,recall_score
from transformers import get_cosine_schedule_with_warmup

import timm


# Utils
import joblib
from tqdm.notebook import tqdm
from collections import defaultdict

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Project root. Set the HW1_DIR environment variable to run on another machine
# instead of editing this absolute path.
hw_dir = os.environ.get("HW1_DIR", "/home/fateplsf/hw/multi_class/hw1")
# Image root: split files (train/val/test.txt) and image paths are resolved against it.
data_dir = os.path.join(hw_dir, "images")

# Directory that run_training writes its best checkpoints into.
os.makedirs(os.path.join(hw_dir, "model"), exist_ok=True)



In [3]:
# Run-wide settings read by the training helpers.
CONFIG = dict(
    # mini-batches whose gradients are summed before each optimizer step
    n_accumulate=1,
    # first CUDA GPU when available, otherwise the CPU
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

In [4]:
def set_seed(seed=42, deterministic=False):
    '''Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Python's ``random`` is seeded too, because augmentation libraries may draw
    their random parameters from it (e.g. Albumentations in many versions --
    confirm for the installed one).

    Args:
        seed: value used for every generator.
        deterministic: when True, select deterministic cuDNN kernels and turn
            off cuDNN autotuning for more repeatable GPU runs at some speed
            cost. The default False keeps the fast setting (``benchmark=True``),
            so GPU results may still differ slightly between runs.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    # Only affects Python interpreters started after this point (e.g. spawned
    # worker processes); the running kernel's hash seed cannot be changed.
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed()

In [5]:
def _read_split(split_name):
    """Read ``<data_dir>/<split_name>.txt``: headerless, whitespace-separated ``path label`` rows."""
    # sep=r'\s+' replaces delim_whitespace=True, which pandas 2.2 deprecated.
    return pd.read_csv(f'{data_dir}/{split_name}.txt', sep=r'\s+', header=None, names=['path', 'label'])

train_df = _read_split('train')
valid_df = _read_split('val')
test_df = _read_split('test')

In [6]:
class ImageDataset(Dataset):
    """Map-style dataset over a DataFrame with ``path`` and ``label`` columns.

    Each item is a dict ``{'image': ..., 'label': <scalar LongTensor>}``.
    Images are read with ``cv2.imread`` (OpenCV's BGR channel order) and scaled
    to float32 in [0, 1]. ``augment`` is applied afterwards when given; it is
    called as ``augment(image=img)["image"]`` (Albumentations-style).

    Args:
        df: DataFrame whose ``path`` entries are relative to ``root``.
        augment: optional transform pipeline.
        root: image root directory. Defaults to the notebook-level ``data_dir``,
            looked up each time an item is read.
    """

    def __init__(self, df, augment=None, root=None):
        self.length = len(df)
        self.df = df
        self.augment = augment
        self.root = root

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        row = self.df.iloc[index]
        root = data_dir if self.root is None else self.root
        img_path = root + "/" + row["path"]
        img = cv2.imread(img_path)
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail with the offending path instead of an opaque AttributeError.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = img.astype(np.float32) / 255
        label = row["label"]

        if self.augment is not None:
            img = self.augment(image=img)["image"]

        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long),
        }

In [7]:
# Albumentations pipelines keyed by split. ImageDataset already scales images
# to float32 in [0, 1], so no A.Normalize step is used here.
_train_steps = [
    A.Resize(384, 384),
    # geometric augmentation
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5),
    # photometric augmentation
    # NOTE(review): for float32 input Albumentations may interpret hue_shift_limit
    # in degrees (0-360), which would make 0.1 a near no-op -- confirm for the
    # installed version.
    A.HueSaturationValue(hue_shift_limit=0.1, sat_shift_limit=0.1, val_shift_limit=0.1, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ToTensorV2(),
]
_valid_steps = [
    A.Resize(384, 384),
    ToTensorV2(),
]
data_transforms = {
    "train": A.Compose(_train_steps, p=1.),
    "valid": A.Compose(_valid_steps, p=1.),
}

In [8]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training epoch with mixed precision and gradient accumulation.

    Relies on the notebook-level ``criterion``, ``scaler`` (a GradScaler) and
    ``CONFIG['n_accumulate']``. The scheduler, if given, is stepped once per
    optimizer step, i.e. every ``n_accumulate`` batches.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: per-step LR scheduler, or None.
        dataloader: yields dicts with 'image' and 'label' tensors.
        device: device each batch is moved to.
        epoch: epoch number, only shown in the progress bar.

    Returns:
        Mean per-sample training loss over the epoch (0.0 for an empty loader).
    """
    model.train()
    n_accumulate = CONFIG['n_accumulate']
    # Drop gradients left over from an incomplete accumulation window of a
    # previous call so they do not leak into this epoch's first update.
    optimizer.zero_grad()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        batch_size = images.size(0)

        # Follow the scaler's switch so autocast and loss scaling are on/off together.
        with amp.autocast(enabled=scaler.is_enabled()):
            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            # Scale down so the accumulated gradient averages over the window.
            loss = batch_loss / n_accumulate

        scaler.scale(loss).backward()

        if (step + 1) % n_accumulate == 0:
            # Gradients are unscaled here, so clipping (e.g. clip_grad_norm_) could go next.
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        # Log the unscaled batch loss; the accumulation-scaled value would
        # under-report by a factor of n_accumulate.
        running_loss += batch_loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()

    return epoch_loss

@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch, optimizer=None):
    """Evaluate ``model`` on ``dataloader`` and compute macro one-vs-rest ROC AUC.

    Uses the notebook-level ``criterion`` for the loss.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields dicts with 'image' and 'label' tensors.
        device: device each batch is moved to.
        epoch: epoch number, only shown in the progress bar.
        optimizer: optional; when given, its current learning rate is shown
            in the progress bar.

    Returns:
        ``(epoch_loss, auc, true_y, pred_y)``: mean per-sample loss, macro OvR
        ROC AUC, the integer labels and the softmax probabilities (NumPy arrays).
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    true_y = []
    pred_y = []

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        batch_size = images.size(0)

        outputs = model(images)
        loss = criterion(outputs, labels)

        true_y.append(labels.cpu().numpy())
        pred_y.append(torch.softmax(outputs, dim=1).cpu().numpy())

        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        postfix = dict(Epoch=epoch, Valid_Loss=epoch_loss)
        if optimizer is not None:
            postfix['LR'] = optimizer.param_groups[0]['lr']
        bar.set_postfix(**postfix)

    true_y = np.concatenate(true_y)
    pred_y = np.concatenate(pred_y)
    gc.collect()

    # sklearn's multi-class OvR AUC requires every class (column of pred_y) to
    # appear in true_y, otherwise it raises ValueError.
    auc = roc_auc_score(true_y, pred_y, multi_class='ovr', average='macro')

    return epoch_loss, auc, true_y, pred_y


def run_training(model, optimizer, scheduler, train_loader, valid_loader, device, num_epochs, ckpt_path=None):
    """Train for ``num_epochs`` epochs, checkpointing whenever validation AUC does not drop.

    After each epoch the model is scored with ``valid_one_epoch``. If that AUC
    is >= the best so far, ``model.state_dict()`` is saved and that epoch's
    validation labels/probabilities are kept.

    Args:
        model, optimizer, scheduler: passed through to ``train_one_epoch``.
        train_loader, valid_loader: training / validation DataLoaders.
        device: device batches are moved to.
        num_epochs: number of epochs to run.
        ckpt_path: file the best weights are written to. Defaults to
            ``hw_dir + '/model/job_<job>_model.bin'`` built from the
            notebook-level ``hw_dir`` and ``job``.

    Returns:
        ``(history, best_true_y, best_pred_y)``. ``history`` maps 'Train Loss',
        'Valid Loss' and 'Valid Auc' to per-epoch lists. The other two are the
        validation labels and probabilities of the best epoch.
    """
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    best_true_y, best_pred_y = [], []
    best_epoch_auc = 0
    history = defaultdict(list)
    start = time.time()

    gc.collect()
    for epoch in range(1, num_epochs + 1):
        gc.collect()
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler,
                                           train_loader, device, epoch=epoch)
        val_epoch_loss, auc, true_y, pred_y = valid_one_epoch(model, valid_loader, device, epoch=epoch)
        print(val_epoch_loss, auc)

        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)
        history['Valid Auc'].append(auc)

        # ">=" lets a later epoch with an equal AUC replace the earlier checkpoint.
        if auc >= best_epoch_auc:
            print(f"Validation Auc Improved ({best_epoch_auc} ---> {auc})")
            best_epoch_auc = auc
            path = ckpt_path if ckpt_path is not None else hw_dir + '/model/' + f"job_{job}_model" + ".bin"
            torch.save(model.state_dict(), path)
            best_true_y, best_pred_y = true_y, pred_y

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))

    return history, best_true_y, best_pred_y

In [9]:
@torch.inference_mode()
def get_score(model, dataloader, device):
    """Evaluate ``model`` and return macro-averaged classification metrics.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields dicts with 'image' and 'label' tensors.
        device: device each batch is moved to.

    Returns:
        ``(auc, f1, precision, recall)``: one-vs-rest ROC AUC on the softmax
        probabilities, plus F1 / precision / recall on the arg-max predictions,
        all macro-averaged over classes.
    """
    model.eval()

    true_y = []
    pred_y = []
    for data in dataloader:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        outputs = model(images)
        true_y.append(labels.cpu().numpy())
        pred_y.append(torch.softmax(outputs, dim=1).cpu().numpy())

    true_y = np.concatenate(true_y)
    pred_y = np.concatenate(pred_y)
    gc.collect()

    auc = roc_auc_score(true_y, pred_y, multi_class='ovr', average='macro')
    pred_labels = np.argmax(pred_y, axis=1)
    f1 = f1_score(true_y, pred_labels, average='macro')
    precision = precision_score(true_y, pred_labels, average='macro')
    recall = recall_score(true_y, pred_labels, average='macro')

    return auc, f1, precision, recall

In [10]:
    
    
class SimpleResNet(nn.Module):
    """Small four-stage CNN with shortcut connections on the last three stages.

    Stem: 7x7 stride-2 conv (3 -> 64) + BN + ReLU. Each later stage is a 3x3
    stride-2 conv + BN; its input is added back through a 1x1 stride-2 conv + BN
    projection when shapes differ (always, with these channel counts), then
    ReLU. Channels grow 64 -> 128 -> 256 -> 512. Global average pooling and a
    linear layer produce ``num_classes`` logits.

    Submodule names (conv2/bn2/downsample2, ...) and their registration order
    are fixed so saved state_dicts keep loading.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        stage_channels = ((64, 128), (128, 256), (256, 512))
        for idx, (c_in, c_out) in enumerate(stage_channels, start=2):
            setattr(self, f"conv{idx}",
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=False))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))
            setattr(self, f"downsample{idx}", nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(c_out),
            ))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))

        for idx in (2, 3, 4):
            shortcut = out
            out = getattr(self, f"bn{idx}")(getattr(self, f"conv{idx}")(out))
            if shortcut.shape != out.shape:
                shortcut = getattr(self, f"downsample{idx}")(shortcut)
            out += shortcut
            out = self.relu(out)

        out = torch.flatten(self.avgpool(out), 1)
        return self.fc(out)

In [11]:
# Two 50-class models trained from scratch (no pretrained weights):
# a timm ResNet-34 baseline and the custom SimpleResNet.
model5 = timm.create_model(
    'resnet34',
    pretrained=False,
    in_chans=3,
    num_classes=50,
)
model6 = SimpleResNet(num_classes=50)

In [12]:
# ptflops counts multiply-accumulate operations (MACs), not FLOPs (about 2 FLOPs per MAC).
macs, params = get_model_complexity_info(model5, (3, 384, 384), as_strings=True, print_per_layer_stat=False)
print(f"MACs: {macs}, Parameters: {params}")

FLOPs: 10.81 GMac, Parameters: 21.31 M


In [13]:
# ptflops counts multiply-accumulate operations (MACs), not FLOPs (about 2 FLOPs per MAC).
macs, params = get_model_complexity_info(model6, (3, 384, 384), as_strings=True, print_per_layer_stat=False)
print(f"MACs: {macs}, Parameters: {params}")

FLOPs: 2.63 GMac, Parameters: 1.76 M


In [14]:
# Training data is shuffled and drops the ragged last batch; validation keeps
# every sample in file order.
train_dataset = ImageDataset(train_df, augment=data_transforms["train"])
valid_dataset = ImageDataset(valid_df, augment=data_transforms["valid"])

_loader_kwargs = dict(batch_size=128, num_workers=16, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)

In [15]:
loss = nn.CrossEntropyLoss()
# Private alias: `loss` is a very common top-level name in a notebook, and
# rebinding it later must not silently change what criterion computes.
_ce_loss = loss


def criterion(outputs, labels):
    """Mean cross-entropy between raw logits ``outputs`` and integer class ``labels``."""
    return _ce_loss(outputs, labels)

In [16]:
# Experiment 5: timm ResNet-34. `job` picks the checkpoint file name run_training saves to.
job = 5
is_amp = True
scaler = amp.GradScaler(enabled=is_amp)  # loss scaler for mixed-precision training

model5 = model5.to(CONFIG['device'])

In [17]:

num_epochs = 24
optimizer = optim.AdamW(model5.parameters(), lr=1.5e-3)
# train_one_epoch steps the scheduler once per optimizer update (every
# n_accumulate batches), so size the schedule in optimizer steps.
steps_per_epoch = len(train_loader) // CONFIG['n_accumulate']
num_train_steps = steps_per_epoch * num_epochs
num_warmup_steps = int(num_train_steps / 10)  # first 10% of steps are warm-up
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

history, true_y, pred_y = run_training(model5, optimizer, scheduler, train_loader, valid_loader,
                                       device=CONFIG['device'], num_epochs=num_epochs)

[INFO] Using GPU: NVIDIA GeForce RTX 3090



  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.6877703401777477 0.7558427815570674
Validation Auc Improved (0 ---> 0.7558427815570674)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.375148867501153 0.8350718065003777
Validation Auc Improved (0.7558427815570674 ---> 0.8350718065003777)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.4627138386832343 0.9123456790123455
Validation Auc Improved (0.8350718065003777 ---> 0.9123456790123455)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

4.734038413365682 0.7816276140085663


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.1649452304840087 0.9395968757873518
Validation Auc Improved (0.9123456790123455 ---> 0.9395968757873518)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.871200507481893 0.9567548500881834
Validation Auc Improved (0.9395968757873518 ---> 0.9567548500881834)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.8190085363388062 0.9592441421012851
Validation Auc Improved (0.9567548500881834 ---> 0.9592441421012851)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.6538842791981168 0.9627865961199294
Validation Auc Improved (0.9592441421012851 ---> 0.9627865961199294)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.8378812795215183 0.9569009826152681


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.4614672078026665 0.975273368606702
Validation Auc Improved (0.9627865961199294 ---> 0.975273368606702)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.3940136249860127 0.9764877802973041
Validation Auc Improved (0.975273368606702 ---> 0.9764877802973041)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.1541874663035074 0.9832199546485261
Validation Auc Improved (0.9764877802973041 ---> 0.9832199546485261)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.1977171240912543 0.980251952632905


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.019126602543725 0.9867069790879315
Validation Auc Improved (0.9832199546485261 ---> 0.9867069790879315)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.2179313418600295 0.9819098009574201


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.9662466830677456 0.9879062736205594
Validation Auc Improved (0.9867069790879315 ---> 0.9879062736205594)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.0124453184339735 0.9867523305618543


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.9190005977948507 0.98843537414966
Validation Auc Improved (0.9879062736205594 ---> 0.98843537414966)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.9664870540301005 0.9874275636180398


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.923258945412106 0.9881078357268833


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.899645623366038 0.9888485764676241
Validation Auc Improved (0.98843537414966 ---> 0.9888485764676241)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.9020584901173909 0.9888435374149658


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.8887472907702129 0.9892517006802721
Validation Auc Improved (0.9888485764676241 ---> 0.9892517006802721)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0.8917235721482171 0.9893272864701436
Validation Auc Improved (0.9892517006802721 ---> 0.9893272864701436)
Training complete in 0h 57m 16s


In [18]:
# Reload the best weights from the same path run_training saves to.
model5.load_state_dict(torch.load(hw_dir + '/model/' + f"job_{job}_model" + ".bin", map_location=CONFIG['device']))

<All keys matched successfully>

In [19]:
# Held-out test split, preprocessed with the validation pipeline (no random augmentation).
test_dataset = ImageDataset(test_df, augment=data_transforms["valid"])
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    num_workers=8,
    shuffle=False,
    pin_memory=True,
)

In [20]:
# Macro-averaged metrics for model5 on the validation split.
auc, f1, precision, recall = get_score(
    model5, valid_loader, device=CONFIG['device'])
print(f"AUC = {auc:.4f}, F1 = {f1:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}")

AUC = 0.9893, F1 = 0.7286, Precision = 0.7379, Recall = 0.7311


In [21]:
# Macro-averaged metrics for model5 on the held-out test split.
auc, f1, precision, recall = get_score(
    model5, test_loader, device=CONFIG['device'])
print(f"AUC = {auc:.4f}, F1 = {f1:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}")

AUC = 0.9905, F1 = 0.7478, Precision = 0.7623, Recall = 0.7489


In [22]:
# Experiment 6: SimpleResNet. `job` picks the checkpoint file name run_training saves to.
job = 6
is_amp = True
scaler = amp.GradScaler(enabled=is_amp)  # fresh loss scaler for this run

model6 = model6.to(CONFIG['device'])

In [23]:

num_epochs = 24
optimizer = optim.AdamW(model6.parameters(), lr=1.5e-3)
# train_one_epoch steps the scheduler once per optimizer update (every
# n_accumulate batches), so size the schedule in optimizer steps.
steps_per_epoch = len(train_loader) // CONFIG['n_accumulate']
num_train_steps = steps_per_epoch * num_epochs
num_warmup_steps = int(num_train_steps / 10)  # first 10% of steps are warm-up
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

history, true_y, pred_y = run_training(model6, optimizer, scheduler, train_loader, valid_loader,
                                       device=CONFIG['device'], num_epochs=num_epochs)

[INFO] Using GPU: NVIDIA GeForce RTX 3090



  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.7142390463087294 0.7436381960191484
Validation Auc Improved (0 ---> 0.7436381960191484)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

5.353985646565755 0.7011337868480725


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

4.023753960927327 0.7606198034769462
Validation Auc Improved (0.7436381960191484 ---> 0.7606198034769462)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.3399415916866726 0.8317712270093223
Validation Auc Improved (0.7606198034769462 ---> 0.8317712270093223)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

4.979291085137262 0.7161552028218696


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.8063888284895153 0.8829176114890401
Validation Auc Improved (0.8317712270093223 ---> 0.8829176114890401)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.901869700749715 0.8781657848324516


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.3063923337724472 0.8539027462836987


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.6477448172039457 0.90302343159486
Validation Auc Improved (0.8829176114890401 ---> 0.90302343159486)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.0248282941182456 0.8783673469387756


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.9248945501115586 0.8883597883597883


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.341019369761149 0.8667271352985638


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

2.2538141176435684 0.929498614260519
Validation Auc Improved (0.90302343159486 ---> 0.929498614260519)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

3.727662158542209 0.8625699168556312


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.9714113876554702 0.9466112370874274
Validation Auc Improved (0.929498614260519 ---> 0.9466112370874274)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.9352951791551378 0.9476643990929705
Validation Auc Improved (0.9466112370874274 ---> 0.9476643990929705)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.8177501175138686 0.954023683547493
Validation Auc Improved (0.9476643990929705 ---> 0.954023683547493)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.8512711487876043 0.9516301335348955


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.751956123246087 0.9559939531368102
Validation Auc Improved (0.954023683547493 ---> 0.9559939531368102)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.784689745903015 0.9529503653313178


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.6977239190207587 0.9589821113630638
Validation Auc Improved (0.9559939531368102 ---> 0.9589821113630638)


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.7105798080232408 0.9572285210380449


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.6885487667719523 0.9579339884101791


  0%|          | 0/494 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

1.6881379991107517 0.9580952380952381
Training complete in 0h 51m 51s


In [24]:
# Reload the best weights from the same path run_training saves to.
model6.load_state_dict(torch.load(hw_dir + '/model/' + f"job_{job}_model" + ".bin", map_location=CONFIG['device']))

<All keys matched successfully>

In [25]:
# Macro-averaged metrics for model6 on the validation split.
auc, f1, precision, recall = get_score(
    model6, valid_loader, device=CONFIG['device'])
print(f"AUC = {auc:.4f}, F1 = {f1:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}")

AUC = 0.9590, F1 = 0.5033, Precision = 0.5283, Recall = 0.5111


In [26]:
# Macro-averaged metrics for model6 on the held-out test split.
auc, f1, precision, recall = get_score(
    model6, test_loader, device=CONFIG['device'])
print(f"AUC = {auc:.4f}, F1 = {f1:.4f}, Precision = {precision:.4f}, Recall = {recall:.4f}")

AUC = 0.9593, F1 = 0.5144, Precision = 0.5320, Recall = 0.5222
