In [4]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import model as eq_model

In [5]:
import numpy as np
import torch.nn as nn
import torch
from torchvision import datasets, models, transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn import metrics
import gc

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Rearrange


    
def metric(y_true, y_pred):
    """Return the area under the ROC curve of ``y_pred`` scored against ``y_true``.

    ``y_true`` holds the ground-truth labels and ``y_pred`` the matching scores;
    both are passed unchanged to ``sklearn.metrics.roc_curve``.
    """
    # roc_curve also returns the decision thresholds, which the area does not need.
    false_positive_rate, true_positive_rate, _ = roc_curve(y_true, y_pred)
    return metrics.auc(false_positive_rate, true_positive_rate)

def straightner(a):
    """Concatenate a sequence of per-batch arrays into one flat float64 vector.

    Parameters
    ----------
    a : sequence of array-like
        Per-batch values (e.g. labels or model outputs), in iteration order.
        Batches may have different lengths; each one is flattened before it
        is appended, so a trailing singleton dimension is tolerated.

    Returns
    -------
    numpy.ndarray
        1-D ``float64`` array with the batches laid out back to back. An empty
        sequence yields an empty array instead of raising.
    """
    if len(a) == 0:
        return np.zeros(0)
    # Concatenating (rather than writing into fixed-size slots) keeps a short
    # final batch intact instead of failing or broadcasting it.
    flat_batches = [np.ravel(batch) for batch in a]
    return np.concatenate(flat_batches).astype(np.float64)

def predictor(outputs):
    """Return, for each row of ``outputs``, the index of its largest entry."""
    predicted_classes = np.argmax(outputs, axis=1)
    return predicted_classes

def _run_epoch(model, loader, criterion, device, log_key, log_interval,
               optimizer=None, scaler=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, auc)``.

    When ``optimizer`` and ``scaler`` are both given, the pass trains ``model``
    under CUDA autocast with gradient scaling. Without them the pass only
    evaluates: gradients are disabled and autocast is off. Every
    ``log_interval`` batches the current batch loss is logged to
    Weights & Biases under ``log_key``.

    Raises
    ------
    ValueError
        If ``loader`` yields no batches (e.g. the dataset is smaller than one
        batch while ``drop_last=True``).
    """
    import wandb  # local import, same as in trainer()

    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0.0
    steps = 0
    label_batches = []
    output_batches = []

    with torch.set_grad_enabled(training):
        for image, label in tqdm(loader):
            image = image.to(device)
            label = label.to(device)
            # Pads the last two dimensions by 2 before / 1 after each, i.e.
            # height and width both grow by 3 pixels.
            # NOTE(review): presumably needed to reach an input size the model
            # accepts; confirm against eq_model.
            image = nn.functional.pad(image, (2, 1, 2, 1))

            if training:
                optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=training):
                outputs = model(image)
                loss = criterion(outputs, label.float())

            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            label_batches.append(label.detach().cpu().numpy())
            output_batches.append(outputs.detach().cpu().numpy())
            total_loss += loss.item()
            steps += 1
            if steps % log_interval == 0:
                wandb.log({log_key: loss.item()})

    if steps == 0:
        raise ValueError("the data loader produced no batches")

    # The AUC is computed on raw model outputs; ranking-based, so no sigmoid is needed.
    auc = metric(straightner(label_batches), straightner(output_batches))
    return total_loss / steps, auc


def trainer(device="cuda:2", epochs=20, batch_size=64, lr=0.001,
            log_interval=50, checkpoint_path=None):
    """Train the equivariant model on ``./Data/Train/`` and validate on ``./Data/Test/``.

    Each epoch runs one training pass and one validation pass, prints the
    mean loss and ROC AUC of both, steps a ``ReduceLROnPlateau`` scheduler on
    the validation AUC and logs the epoch summary to the Weights & Biases
    project ``"Equivariant"`` (run name ``"D4"``).

    Parameters
    ----------
    device : str
        Device the model and batches are moved to.
    epochs : int
        Number of epochs to run.
    batch_size : int
        Batch size of both data loaders.
    lr : float
        Initial Adam learning rate.
    log_interval : int
        Number of batches between per-batch loss logs; also used as the
        ``log_freq`` of ``wandb.watch``.
    checkpoint_path : str or None
        When given, the model/optimizer/scheduler state of the epoch with the
        best validation AUC so far is saved to this path. ``None`` (default)
        writes nothing.

    Returns
    -------
    None
    """
    import wandb

    # NOTE(review): "dihyderal" looks like a misspelling of "dihedral"; it is
    # left untouched because eq_model is not visible here. Confirm which
    # spelling eq_model.model accepts.
    model = eq_model.model(channels=3, N=4, group="dihyderal")
    # Move the model to its device before the optimizer is constructed.
    model = model.to(device)

    train_transform = transforms.Compose([transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    dataset_Train = datasets.ImageFolder('./Data/Train/', transform=train_transform)
    dataset_Test = datasets.ImageFolder('./Data/Test/', transform=test_transform)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_Train, batch_size=batch_size, shuffle=True, drop_last=True,
        num_workers=4, pin_memory=True)
    # Evaluation does not shuffle, so every epoch is scored on the same images.
    # drop_last=True is kept so every batch has the same size; up to
    # batch_size - 1 test images are therefore never scored.
    dataloader_test = torch.utils.data.DataLoader(
        dataset_Test, batch_size=batch_size, shuffle=False, drop_last=True,
        num_workers=4, pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    # mode 'max' because the monitored quantity is the validation AUC.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'max', verbose=True, threshold=0.0001, patience=3, factor=0.5)
    scaler = torch.cuda.amp.GradScaler()

    # Never hard-code the API key: wandb.login() reads it from the
    # WANDB_API_KEY environment variable or an existing netrc login.
    wandb.login()
    wandb.init(
        project="Equivariant",
        name="D4"
    )

    try:
        wandb.watch(model, log_freq=log_interval)
        best_auc = float("-inf")

        for epoch in range(epochs):
            train_loss, train_auc = _run_epoch(
                model, dataloader_train, criterion, device, "loss", log_interval,
                optimizer=optimizer, scaler=scaler)
            val_loss, test_auc = _run_epoch(
                model, dataloader_test, criterion, device, "val_loss", log_interval)

            print("----------------------------------------------------")
            print("Epoch No" , epoch)
            print("The Training loss of the epoch, ",train_loss)
            print("The Training AUC of the epoch,  %.5f"%train_auc)
            print("The validation loss of the epoch, ",val_loss)
            print("The validation AUC of the epoch, %.5f"%test_auc)
            print("----------------------------------------------------")

            # Keep only the best epoch: later epochs may score worse on validation.
            if checkpoint_path is not None and test_auc > best_auc:
                best_auc = test_auc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, checkpoint_path)

            scheduler.step(test_auc)
            # Read the learning rate from the optimizer rather than from the
            # scheduler's private ``_last_lr`` attribute.
            curr_lr = optimizer.param_groups[0]["lr"]
            wandb.log({"Train_auc_epoch": train_auc,
                       "Epoch": epoch,
                       "Val_auc_epoch": test_auc,
                       "Train_loss_epoch": train_loss,
                       "Val_loss_epoch": val_loss,
                       "Lr": curr_lr})
            gc.collect()
    finally:
        # Close the W&B run even when training is interrupted or fails.
        wandb.finish()


In [6]:
# Start the training run; tqdm progress bars and the per-epoch loss/AUC report print below.
trainer()

[34m[1mwandb[0m: Currently logged in as: [33mdc250601[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/diptarko/.netrc


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:16<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 0
The Training loss of the epoch,  0.5792995386767661
The Training AUC of the epoch,  0.76646
The validation loss of the epoch,  0.5830069868729032
The validation AUC of the epoch, 0.77629
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:23<00:00,  2.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.09it/s]


----------------------------------------------------
Epoch No 1
The Training loss of the epoch,  0.5637668583756206
The Training AUC of the epoch,  0.78199
The validation loss of the epoch,  0.5625567312213197
The validation AUC of the epoch, 0.78373
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:15<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 2
The Training loss of the epoch,  0.5582308091993989
The Training AUC of the epoch,  0.78722
The validation loss of the epoch,  0.563340816895167
The validation AUC of the epoch, 0.78805
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:13<00:00,  2.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.13it/s]


----------------------------------------------------
Epoch No 3
The Training loss of the epoch,  0.5547424389028001
The Training AUC of the epoch,  0.79058
The validation loss of the epoch,  0.5614297122105785
The validation AUC of the epoch, 0.78966
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:17<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 4
The Training loss of the epoch,  0.5516142075945591
The Training AUC of the epoch,  0.79361
The validation loss of the epoch,  0.5609466643169009
The validation AUC of the epoch, 0.78860
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:17<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 5
The Training loss of the epoch,  0.5481977886338344
The Training AUC of the epoch,  0.79673
The validation loss of the epoch,  0.5707538995934629
The validation AUC of the epoch, 0.79175
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:19<00:00,  2.56it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 6
The Training loss of the epoch,  0.5449660068955915
The Training AUC of the epoch,  0.79980
The validation loss of the epoch,  0.554100673705682
The validation AUC of the epoch, 0.79204
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:22<00:00,  2.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 7
The Training loss of the epoch,  0.5426743176476708
The Training AUC of the epoch,  0.80183
The validation loss of the epoch,  0.5563405058164707
The validation AUC of the epoch, 0.79544
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:13<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 8
The Training loss of the epoch,  0.5383364344465321
The Training AUC of the epoch,  0.80563
The validation loss of the epoch,  0.5538546557399048
The validation AUC of the epoch, 0.79231
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:15<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 9
The Training loss of the epoch,  0.5335687510412315
The Training AUC of the epoch,  0.80979
The validation loss of the epoch,  0.5522924605457262
The validation AUC of the epoch, 0.79400
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:12<00:00,  2.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 10
The Training loss of the epoch,  0.5286075518049043
The Training AUC of the epoch,  0.81412
The validation loss of the epoch,  0.5615879930298904
The validation AUC of the epoch, 0.79072
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:16<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.12it/s]


----------------------------------------------------
Epoch No 11
The Training loss of the epoch,  0.5205660484303003
The Training AUC of the epoch,  0.82078
The validation loss of the epoch,  0.5655126444909765
The validation AUC of the epoch, 0.78772
----------------------------------------------------
Epoch 00012: reducing learning rate of group 0 to 5.0000e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:13<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 12
The Training loss of the epoch,  0.4945028548096788
The Training AUC of the epoch,  0.84106
The validation loss of the epoch,  0.5794312290076552
The validation AUC of the epoch, 0.77749
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:18<00:00,  2.56it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.11it/s]


----------------------------------------------------
Epoch No 13
The Training loss of the epoch,  0.4718054872134636
The Training AUC of the epoch,  0.85705
The validation loss of the epoch,  0.5983985052026551
The validation AUC of the epoch, 0.77130
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:14<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 14
The Training loss of the epoch,  0.4458098532653403
The Training AUC of the epoch,  0.87392
The validation loss of the epoch,  0.6273632495567716
The validation AUC of the epoch, 0.75856
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:21<00:00,  2.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.05it/s]


----------------------------------------------------
Epoch No 15
The Training loss of the epoch,  0.4134035998206029
The Training AUC of the epoch,  0.89278
The validation loss of the epoch,  0.6641519918523986
The validation AUC of the epoch, 0.75754
----------------------------------------------------
Epoch 00016: reducing learning rate of group 0 to 2.5000e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:20<00:00,  2.56it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 16
The Training loss of the epoch,  0.3438774799735382
The Training AUC of the epoch,  0.92709
The validation loss of the epoch,  0.7690779353010243
The validation AUC of the epoch, 0.74029
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:15<00:00,  2.58it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 17
The Training loss of the epoch,  0.3040646186419602
The Training AUC of the epoch,  0.94313
The validation loss of the epoch,  0.8305893747285865
The validation AUC of the epoch, 0.72747
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:21<00:00,  2.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.10it/s]


----------------------------------------------------
Epoch No 18
The Training loss of the epoch,  0.265648663557809
The Training AUC of the epoch,  0.95650
The validation loss of the epoch,  0.9317625254050068
The validation AUC of the epoch, 0.71979
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [11:15<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [01:01<00:00,  7.08it/s]


----------------------------------------------------
Epoch No 19
The Training loss of the epoch,  0.22721577738037055
The Training AUC of the epoch,  0.96798
The validation loss of the epoch,  1.0182523506811296
The validation AUC of the epoch, 0.71367
----------------------------------------------------
Epoch 00020: reducing learning rate of group 0 to 1.2500e-04.


0,1
Epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
Lr,███████████▄▄▄▄▂▂▂▂▁
Train_auc_epoch,▁▂▂▂▂▂▂▂▂▃▃▃▄▄▅▅▇▇██
Train_loss_epoch,████▇▇▇▇▇▇▇▇▆▆▅▅▃▃▂▁
Val_auc_epoch,▆▇▇█▇██████▇▆▆▅▅▃▂▂▁
Val_loss_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▄▅▇█
loss,█▆▇▅▅▅▇▇▇▆▇▆▅▆▇▆▅▆▆▇▅▅▅▇▃▆▆▅▆▅▆▅▂▃▂▃▁▃▂▂
val_loss,▂▃▂▂▃▂▂▄▂▂▂▄▂▂▂▂▄▁▃▂▂▂▂▂▃▃▃▄▅▁▄▃▄▄▆▅▆█▅█

0,1
Epoch,19.0
Lr,0.00013
Train_auc_epoch,0.96798
Train_loss_epoch,0.22722
Val_auc_epoch,0.71367
Val_loss_epoch,1.01825
loss,0.23553
val_loss,1.13875
