In [4]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import model as eq_model

In [5]:
import numpy as np
import torch.nn as nn
import torch
from torchvision import datasets, models, transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn import metrics
import gc

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Rearrange


    
def metric(y_true, y_pred):
    """Return the area under the ROC curve for the given labels and scores.

    Args:
        y_true: Ground-truth binary labels, one per sample.
        y_pred: Predicted scores (e.g. logits or probabilities), one per sample.

    Returns:
        The ROC AUC as computed by ``sklearn.metrics.auc`` over the curve
        produced by ``roc_curve``.
    """
    # The thresholds returned by roc_curve are not needed for the area.
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred)
    return metrics.auc(false_pos_rate, true_pos_rate)

def straightner(a):
    """Flatten a sequence of per-batch arrays into one 1-D float64 array.

    Batches are laid out back to back in the order given. Unlike a
    fixed-stride copy, batches may differ in length (e.g. a shorter final
    batch when the loader does not drop the last one), and each batch is
    ravelled first so column-shaped ``(n, 1)`` batches are accepted as well.

    Args:
        a: Sequence of array-likes, one entry per batch.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype float64 containing every element of
        every batch. An empty sequence yields an empty array.
    """
    if len(a) == 0:
        return np.zeros(0)
    # Cast to float64 so the result dtype does not depend on the batch dtype
    # (e.g. float16 outputs produced under autocast).
    return np.concatenate([np.ravel(batch) for batch in a]).astype(np.float64)

def predictor(outputs):
    """Pick the highest-scoring column index for every row of ``outputs``.

    Args:
        outputs: 2-D array-like of per-class scores, one row per sample.

    Returns:
        Array of integer indices, one per row, from ``np.argmax`` along axis 1.
    """
    predicted_classes = np.argmax(outputs, axis=1)
    return predicted_classes

def _concat_batches(batches):
    """Join per-batch arrays into one flat float64 array (ragged batches allowed)."""
    return np.concatenate([np.ravel(b) for b in batches]).astype(np.float64)


def trainer(device="cuda:1", num_epochs=20, checkpoint_path=None):
    """Train ``eq_model.model(channels=3, N=8, group="cyclic")`` as a binary classifier.

    Images are read with ``ImageFolder`` from ``./Data/Train/`` and
    ``./Data/Test/``. Training uses Adam, ``BCEWithLogitsLoss``, CUDA mixed
    precision (autocast + ``GradScaler``) and a ``ReduceLROnPlateau``
    scheduler stepped on the validation ROC AUC. Per-step and per-epoch
    metrics are printed and logged to Weights & Biases (project
    "Equivariant", run name "C8").

    Weights & Biases credentials are NOT embedded in code: ``wandb.login()``
    reads the ``WANDB_API_KEY`` environment variable or previously stored
    credentials. NOTE(review): a key was previously committed in this
    notebook and should be revoked.

    Args:
        device: Torch device string the model and batches are moved to. The
            mixed-precision code path is CUDA specific, so this is expected
            to be a CUDA device.
        num_epochs: Number of passes over the training set.
        checkpoint_path: If given, the model/optimizer/scheduler state of the
            epoch with the best validation AUC so far is written to this
            path. ``None`` (default) disables checkpointing.

    Returns:
        None.
    """
    import wandb

    model = eq_model.model(channels=3, N=8, group="cyclic")

    train_transform = transforms.Compose([transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    dataset_Train = datasets.ImageFolder('./Data/Train/', transform=train_transform)
    dataset_Test = datasets.ImageFolder('./Data/Test/', transform=test_transform)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_Train, batch_size=64, shuffle=True, drop_last=True,
        num_workers=4, pin_memory=True)
    # Evaluation must see every sample, in a fixed order: shuffling together
    # with drop_last would silently discard a different random subset of the
    # test set on every epoch and make epochs non-comparable.
    dataloader_test = torch.utils.data.DataLoader(
        dataset_Test, batch_size=64, shuffle=False, drop_last=False,
        num_workers=4, pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCEWithLogitsLoss()
    # mode='max' because the scheduler is stepped with the validation AUC.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'max', verbose=True, threshold=0.0001, patience=3, factor=0.5)

    model = model.to(device)

    # Credentials come from WANDB_API_KEY or stored login, never from source.
    wandb.login()
    wandb.init(
        project="Equivariant",
        name="C8"
    )

    try:
        scaler = torch.cuda.amp.GradScaler()
        wandb.watch(model, log_freq=50)
        w_intr = 50  # log a per-step loss to W&B every `w_intr` batches
        best_auc = float("-inf")

        for epoch in range(num_epochs):
            # ---------------------------- training ----------------------------
            train_loss = 0
            train_steps = 0
            label_list = []
            outputs_list = []
            model.train()
            for image, label in tqdm(dataloader_train):
                image = image.to(device)
                label = label.to(device)
                # Pad 2 px on the left/top and 1 px on the right/bottom;
                # presumably to reach the spatial size the model expects
                # (TODO confirm against eq_model).
                image = nn.functional.pad(image, (2, 1, 2, 1))
                optimizer.zero_grad(set_to_none=True)

                with torch.cuda.amp.autocast():
                    outputs = model(image)
                    loss = criterion(outputs, label.float())
                label_list.append(label.detach().cpu().numpy())
                outputs_list.append(outputs.detach().cpu().numpy())
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                train_loss += loss.item()
                train_steps += 1
                if train_steps % w_intr == 0:
                    wandb.log({"loss": loss.item()})
            train_auc = metric(_concat_batches(label_list),
                               _concat_batches(outputs_list))

            # --------------------------- validation ---------------------------
            model.eval()
            val_loss = 0.0
            val_samples = 0
            test_steps = 0
            label_list = []
            outputs_list = []
            with torch.no_grad():
                for image, label in tqdm(dataloader_test):
                    image = image.to(device)
                    label = label.to(device)
                    image = nn.functional.pad(image, (2, 1, 2, 1))
                    # The loss requires outputs and labels to share a shape,
                    # so this reshape is a no-op for regular batches; it only
                    # guards a size-1 final batch in case the model squeezes
                    # it to a scalar (assumption, model code not visible).
                    outputs = model(image).reshape(label.shape)
                    loss = criterion(outputs, label.float())
                    label_list.append(label.detach().cpu().numpy())
                    outputs_list.append(outputs.detach().cpu().numpy())
                    # Weight by batch size so a shorter final batch does not
                    # bias the epoch mean.
                    n_batch = label.size(0)
                    val_loss += loss.item() * n_batch
                    val_samples += n_batch
                    test_steps += 1
                    if test_steps % w_intr == 0:
                        wandb.log({"val_loss": loss.item()})
            test_auc = metric(_concat_batches(label_list),
                              _concat_batches(outputs_list))

            train_loss = train_loss / train_steps
            val_loss = val_loss / val_samples

            print("----------------------------------------------------")
            print("Epoch No" , epoch)
            print("The Training loss of the epoch, ",train_loss)
            print("The Training AUC of the epoch,  %.5f"%train_auc)
            print("The validation loss of the epoch, ",val_loss)
            print("The validation AUC of the epoch, %.5f"%test_auc)
            print("----------------------------------------------------")

            # Keep only the best-validation weights: later epochs can overfit
            # badly, so the final weights are not necessarily the ones to keep.
            if checkpoint_path is not None and test_auc > best_auc:
                best_auc = test_auc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'val_auc': test_auc,
                }, checkpoint_path)

            scheduler.step(test_auc)
            # Read the LR from the optimizer rather than the scheduler's
            # private `_last_lr` attribute.
            curr_lr = optimizer.param_groups[0]["lr"]
            wandb.log({"Train_auc_epoch": train_auc,
                       "Epoch": epoch,
                       "Val_auc_epoch": test_auc,
                       "Train_loss_epoch": train_loss,
                       "Val_loss_epoch": val_loss,
                       "Lr": curr_lr})
            gc.collect()
    finally:
        # Always close the W&B run, even if training is interrupted or fails.
        wandb.finish()


In [6]:
# Launch the full training run; tqdm progress bars and per-epoch loss/AUC
# are printed below, and the same metrics are logged to Weights & Biases.
trainer()

[34m[1mwandb[0m: Currently logged in as: [33mdc250601[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/diptarko/.netrc


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:09<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.09it/s]


----------------------------------------------------
Epoch No 0
The Training loss of the epoch,  0.5769874441041344
The Training AUC of the epoch,  0.76896
The validation loss of the epoch,  0.5777225972592146
The validation AUC of the epoch, 0.77869
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:16<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.08it/s]


----------------------------------------------------
Epoch No 1
The Training loss of the epoch,  0.5637606028338958
The Training AUC of the epoch,  0.78194
The validation loss of the epoch,  0.5769779804794267
The validation AUC of the epoch, 0.78502
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:13<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:00<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 2
The Training loss of the epoch,  0.5587452794457304
The Training AUC of the epoch,  0.78676
The validation loss of the epoch,  0.5682234488684555
The validation AUC of the epoch, 0.78650
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:09<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 3
The Training loss of the epoch,  0.5544563652968955
The Training AUC of the epoch,  0.79097
The validation loss of the epoch,  0.5547332948651807
The validation AUC of the epoch, 0.79140
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:00<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 4
The Training loss of the epoch,  0.5514984331425579
The Training AUC of the epoch,  0.79369
The validation loss of the epoch,  0.5608880225954385
The validation AUC of the epoch, 0.79259
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 5
The Training loss of the epoch,  0.5475334376096725
The Training AUC of the epoch,  0.79737
The validation loss of the epoch,  0.566116482430491
The validation AUC of the epoch, 0.79217
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 6
The Training loss of the epoch,  0.5439946938862745
The Training AUC of the epoch,  0.80059
The validation loss of the epoch,  0.5537958133494717
The validation AUC of the epoch, 0.79524
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:09<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 7
The Training loss of the epoch,  0.5391037986196321
The Training AUC of the epoch,  0.80506
The validation loss of the epoch,  0.5705277750546904
The validation AUC of the epoch, 0.79359
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 8
The Training loss of the epoch,  0.5340814914168983
The Training AUC of the epoch,  0.80941
The validation loss of the epoch,  0.5681234809173935
The validation AUC of the epoch, 0.78888
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 9
The Training loss of the epoch,  0.5256443499833688
The Training AUC of the epoch,  0.81652
The validation loss of the epoch,  0.5784637078471567
The validation AUC of the epoch, 0.78996
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:08<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 10
The Training loss of the epoch,  0.5141777230576537
The Training AUC of the epoch,  0.82599
The validation loss of the epoch,  0.58060413770292
The validation AUC of the epoch, 0.77505
----------------------------------------------------
Epoch 00011: reducing learning rate of group 0 to 5.0000e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:09<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 11
The Training loss of the epoch,  0.47775654088834235
The Training AUC of the epoch,  0.85311
The validation loss of the epoch,  0.6268453497311165
The validation AUC of the epoch, 0.75961
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:09<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 12
The Training loss of the epoch,  0.44244907789531795
The Training AUC of the epoch,  0.87601
The validation loss of the epoch,  0.6413778079652238
The validation AUC of the epoch, 0.75823
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:10<00:00,  2.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 13
The Training loss of the epoch,  0.403695464305494
The Training AUC of the epoch,  0.89814
The validation loss of the epoch,  0.6646416096851744
The validation AUC of the epoch, 0.75113
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:11<00:00,  2.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 14
The Training loss of the epoch,  0.3586579474909552
The Training AUC of the epoch,  0.92034
The validation loss of the epoch,  0.7751011482600508
The validation AUC of the epoch, 0.73698
----------------------------------------------------
Epoch 00015: reducing learning rate of group 0 to 2.5000e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:16<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.04it/s]


----------------------------------------------------
Epoch No 15
The Training loss of the epoch,  0.26828875575942557
The Training AUC of the epoch,  0.95521
The validation loss of the epoch,  0.8766856230538467
The validation AUC of the epoch, 0.71933
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:21<00:00,  2.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 16
The Training loss of the epoch,  0.21755408948694152
The Training AUC of the epoch,  0.96997
The validation loss of the epoch,  0.9854201438098118
The validation AUC of the epoch, 0.72323
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:14<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.09it/s]


----------------------------------------------------
Epoch No 17
The Training loss of the epoch,  0.1746228604866513
The Training AUC of the epoch,  0.98017
The validation loss of the epoch,  1.1756191618141087
The validation AUC of the epoch, 0.70791
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:14<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.09it/s]


----------------------------------------------------
Epoch No 18
The Training loss of the epoch,  0.13988634393304245
The Training AUC of the epoch,  0.98695
The validation loss of the epoch,  1.2521948792468542
The validation AUC of the epoch, 0.70322
----------------------------------------------------
Epoch 00019: reducing learning rate of group 0 to 1.2500e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:12<00:00,  2.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 19
The Training loss of the epoch,  0.0891926135279067
The Training AUC of the epoch,  0.99402
The validation loss of the epoch,  1.4763251939039121
The validation AUC of the epoch, 0.70019
----------------------------------------------------


0,1
Epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
Lr,██████████▄▄▄▄▂▂▂▂▁▁
Train_auc_epoch,▁▁▂▂▂▂▂▂▂▂▃▄▄▅▆▇▇███
Train_loss_epoch,███████▇▇▇▇▇▆▆▅▄▃▂▂▁
Val_auc_epoch,▇▇▇███████▇▅▅▅▄▂▃▂▁▁
Val_loss_epoch,▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▆▆█
loss,▇▆▆▇█▇▇█▇▇▇▇█▇▆▆▆▆▆▇▆▅▅▆▄▆▄▄▃▄▄▃▂▃▂▃▃▃▁▁
val_loss,▂▂▃▁▁▁▁▁▂▂▂▂▂▂▂▂▁▁▂▂▃▂▂▂▂▁▂▂▃▂▅▆▄▄▆▅▆█▅█

0,1
Epoch,19.0
Lr,0.00013
Train_auc_epoch,0.99402
Train_loss_epoch,0.08919
Val_auc_epoch,0.70019
Val_loss_epoch,1.47633
loss,0.04378
val_loss,1.36188
