In [4]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import model as eq_model

In [5]:
import numpy as np
import torch.nn as nn
import torch
from torchvision import datasets, models, transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn import metrics
import gc

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Rearrange


    
def metric(y_true, y_pred):
    """Return the area under the ROC curve for binary labels and scores.

    Parameters
    ----------
    y_true : array-like
        Ground-truth binary labels, one per sample.
    y_pred : array-like
        One score per sample for the positive class. ``roc_curve`` only uses
        the ordering of the scores, so raw logits are acceptable.

    Returns
    -------
    float
        Trapezoidal area under the (false-positive-rate, true-positive-rate)
        curve produced by ``roc_curve``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_pred)
    area_under_curve = metrics.auc(false_pos_rate, true_pos_rate)
    return area_under_curve

def straightner(a):
    """Concatenate a list of per-batch arrays into one flat float64 vector.

    Used to turn the per-batch label / output lists collected during an epoch
    into 1-D arrays that can be passed to ``metric``.

    Parameters
    ----------
    a : sequence of array-like
        Per-batch arrays. Each element is flattened with ``np.ravel``, so both
        1-D batches and column vectors of shape (k, 1) are accepted, and the
        batches may differ in length (e.g. a smaller final batch from a
        DataLoader built with ``drop_last=False``).

    Returns
    -------
    numpy.ndarray
        1-D float64 array holding every element in order; an empty array if
        ``a`` is empty.
    """
    if len(a) == 0:
        return np.zeros(0)
    # np.concatenate copes with batches of different lengths; a fixed-stride
    # copy sized from a[0] either raises or silently broadcasts a length-1
    # batch over a whole slot when lengths differ.
    return np.concatenate([np.ravel(batch) for batch in a]).astype(np.float64)

def predictor(outputs):
    """Return, for each row of ``outputs``, the column index of its largest value.

    NOTE(review): this is an argmax over axis 1, so it expects one column per
    class. It is not called in the visible notebook cells, and ``trainer``
    pairs the model output with ``BCEWithLogitsLoss`` against a per-sample
    label (one logit per image), so it would not apply to those outputs
    as-is -- confirm the intended use before relying on it.
    """
    class_axis = 1
    predicted_classes = np.argmax(outputs, axis=class_axis)
    return predicted_classes

# F.pad order for a 4-D input is (left, right, top, bottom): 2 px before and
# 1 px after each spatial dimension. NOTE(review): presumably brings images to
# the size the equivariant model expects -- confirm against model.py.
_INPUT_PAD = (2, 1, 2, 1)


def _flatten_batches(batches):
    """Concatenate per-batch arrays (lengths may differ) into one float64 vector."""
    return np.concatenate([np.ravel(batch) for batch in batches]).astype(np.float64)


def _train_one_epoch(model, loader, optimizer, criterion, scaler, device, use_amp,
                     log_every, log_fn):
    """Train ``model`` for one pass over ``loader``, optionally in mixed precision.

    Calls ``log_fn({"loss": ...})`` with the current batch loss every
    ``log_every`` steps.

    Returns
    -------
    (float, float)
        Sample-weighted mean training loss and ROC AUC over every training
        sample seen this epoch (scores are the raw model outputs).

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    model.train()
    device_type = torch.device(device).type
    total_loss, n_samples, step = 0.0, 0, 0
    labels_seen, scores_seen = [], []
    for image, label in tqdm(loader):
        image = nn.functional.pad(image.to(device, non_blocking=True), _INPUT_PAD)
        label = label.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, enabled=use_amp):
            outputs = model(image)
            loss = criterion(outputs, label.float())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        labels_seen.append(label.cpu().numpy())
        # Under autocast the outputs may be half precision; upcast before numpy.
        scores_seen.append(outputs.detach().float().cpu().numpy())
        batch_loss = loss.item()
        total_loss += batch_loss * label.size(0)
        n_samples += label.size(0)
        step += 1
        if step % log_every == 0:
            log_fn({"loss": batch_loss})

    if n_samples == 0:
        raise ValueError("training DataLoader yielded no batches")
    train_auc = metric(_flatten_batches(labels_seen), _flatten_batches(scores_seen))
    return total_loss / n_samples, train_auc


def _evaluate(model, loader, criterion, device, log_every, log_fn):
    """Score ``model`` on every batch of ``loader`` without gradient tracking.

    Calls ``log_fn({"val_loss": ...})`` with the current batch loss every
    ``log_every`` steps.

    Returns
    -------
    (float, float)
        Sample-weighted mean loss and ROC AUC over all samples in ``loader``.

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    model.eval()
    total_loss, n_samples, step = 0.0, 0, 0
    labels_seen, scores_seen = [], []
    with torch.no_grad():
        for image, label in tqdm(loader):
            image = nn.functional.pad(image.to(device, non_blocking=True), _INPUT_PAD)
            label = label.to(device, non_blocking=True)
            outputs = model(image)
            loss = criterion(outputs, label.float())

            labels_seen.append(label.cpu().numpy())
            scores_seen.append(outputs.float().cpu().numpy())
            batch_loss = loss.item()
            total_loss += batch_loss * label.size(0)
            n_samples += label.size(0)
            step += 1
            if step % log_every == 0:
                log_fn({"val_loss": batch_loss})

    if n_samples == 0:
        raise ValueError("validation DataLoader yielded no batches")
    val_auc = metric(_flatten_batches(labels_seen), _flatten_batches(scores_seen))
    return total_loss / n_samples, val_auc


def trainer(device="cuda:3", epochs=20, data_root="./Data", batch_size=64,
            log_every=50, save_checkpoints=False):
    """Train the ``eq_model`` classifier and log metrics to Weights & Biases.

    Builds ``eq_model.model(channels=3, N=8, group="dihyderal")``. It trains it
    with Adam + ``BCEWithLogitsLoss`` (mixed precision on CUDA) and evaluates
    loss and ROC AUC on the full test split after every epoch. The learning
    rate is halved when the validation AUC plateaus.

    Parameters
    ----------
    device : str
        Torch device to run on (default: the GPU used for the original run).
    epochs : int
        Number of passes over the training set.
    data_root : str
        Directory containing ``Train/`` and ``Test/`` in ``ImageFolder`` layout.
    batch_size : int
        Batch size for both loaders.
    log_every : int
        Log the current batch loss to W&B every ``log_every`` steps.
    save_checkpoints : bool
        If True, write ``model_Epoch_{epoch}.pt`` with model, optimizer and
        scheduler state after every epoch.

    Notes
    -----
    W&B credentials are taken by ``wandb.login()`` from the ``WANDB_API_KEY``
    environment variable or an existing ``~/.netrc`` entry.
    """
    import os
    import wandb

    device_type = torch.device(device).type
    # Mixed precision (autocast + GradScaler) is only enabled on CUDA devices.
    use_amp = device_type == "cuda"

    # NOTE(review): "dihyderal" is passed through verbatim; presumably it is the
    # key model.py expects for the dihedral group -- confirm before changing it.
    model = eq_model.model(channels=3, N=8, group="dihyderal").to(device)

    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset_train = datasets.ImageFolder(os.path.join(data_root, "Train"), transform=to_tensor)
    dataset_test = datasets.ImageFolder(os.path.join(data_root, "Test"), transform=to_tensor)
    loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    dataloader_train = DataLoader(dataset_train, shuffle=True, drop_last=True, **loader_kwargs)
    # Evaluate on every test image in a fixed order: shuffling adds nothing to
    # the metrics, and drop_last would silently discard up to batch_size-1 samples.
    dataloader_test = DataLoader(dataset_test, shuffle=False, drop_last=False, **loader_kwargs)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCEWithLogitsLoss()
    # mode="max" because the scheduler is stepped with validation AUC.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", threshold=0.0001, patience=3, factor=0.5)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Never hardcode the API key in the notebook; wandb.login() reads
    # WANDB_API_KEY or ~/.netrc.
    wandb.login()
    wandb.init(project="Equivariant", name="D8")
    try:
        wandb.watch(model, log_freq=50)
        for epoch in range(epochs):
            train_loss, train_auc = _train_one_epoch(
                model, dataloader_train, optimizer, criterion, scaler, device,
                use_amp, log_every, wandb.log)
            val_loss, test_auc = _evaluate(
                model, dataloader_test, criterion, device, log_every, wandb.log)

            print("----------------------------------------------------")
            print("Epoch No", epoch)
            print("The Training loss of the epoch, ", train_loss)
            print("The Training AUC of the epoch,  %.5f" % train_auc)
            print("The validation loss of the epoch, ", val_loss)
            print("The validation AUC of the epoch, %.5f" % test_auc)
            print("----------------------------------------------------")

            if save_checkpoints:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, f"model_Epoch_{epoch}.pt")

            prev_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(test_auc)
            # Public optimizer API instead of the scheduler's private _last_lr.
            curr_lr = optimizer.param_groups[0]["lr"]
            if curr_lr < prev_lr:
                print(f"Epoch {epoch}: reducing learning rate to {curr_lr:.4e}.")

            wandb.log({"Train_auc_epoch": train_auc,
                       "Epoch": epoch,
                       "Val_auc_epoch": test_auc,
                       "Train_loss_epoch": train_loss,
                       "Val_loss_epoch": val_loss,
                       "Lr": curr_lr})
            gc.collect()
    finally:
        # Close the W&B run even if training raises or is interrupted.
        wandb.finish()


In [None]:
# Run the full training loop defined above; per-epoch losses/AUCs are printed
# in the output below and logged to Weights & Biases.
trainer()

[34m[1mwandb[0m: Currently logged in as: [33mdc250601[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/diptarko/.netrc


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:48<00:00,  1.20s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:18<00:00,  2.19it/s]


----------------------------------------------------
Epoch No 0
The Training loss of the epoch,  0.5807615502127286
The Training AUC of the epoch,  0.76460
The validation loss of the epoch,  0.5717705336110345
The validation AUC of the epoch, 0.77521
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:58<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:29<00:00,  2.07it/s]


----------------------------------------------------
Epoch No 1
The Training loss of the epoch,  0.5617238511127988
The Training AUC of the epoch,  0.78401
The validation loss of the epoch,  0.5612296441505695
The validation AUC of the epoch, 0.78677
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:57<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:29<00:00,  2.07it/s]


----------------------------------------------------
Epoch No 2
The Training loss of the epoch,  0.5559861934733117
The Training AUC of the epoch,  0.78960
The validation loss of the epoch,  0.5604981316232133
The validation AUC of the epoch, 0.78865
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [35:01<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:30<00:00,  2.07it/s]


----------------------------------------------------
Epoch No 3
The Training loss of the epoch,  0.5520208834916696
The Training AUC of the epoch,  0.79324
The validation loss of the epoch,  0.554742815850795
The validation AUC of the epoch, 0.79155
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:57<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:30<00:00,  2.07it/s]


----------------------------------------------------
Epoch No 4
The Training loss of the epoch,  0.5479065686293032
The Training AUC of the epoch,  0.79703
The validation loss of the epoch,  0.5594397825756292
The validation AUC of the epoch, 0.79422
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [35:03<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:30<00:00,  2.07it/s]


----------------------------------------------------
Epoch No 5
The Training loss of the epoch,  0.5441983305688562
The Training AUC of the epoch,  0.80046
The validation loss of the epoch,  0.5509804507096608
The validation AUC of the epoch, 0.79577
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:57<00:00,  1.21s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:28<00:00,  2.09it/s]


----------------------------------------------------
Epoch No 6
The Training loss of the epoch,  0.5402696353265609
The Training AUC of the epoch,  0.80408
The validation loss of the epoch,  0.5510548414855168
The validation AUC of the epoch, 0.79500
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:51<00:00,  1.20s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:28<00:00,  2.09it/s]


----------------------------------------------------
Epoch No 7
The Training loss of the epoch,  0.5355786830186844
The Training AUC of the epoch,  0.80820
The validation loss of the epoch,  0.5695691166938036
The validation AUC of the epoch, 0.79406
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:31<00:00,  1.19s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:25<00:00,  2.12it/s]


----------------------------------------------------
Epoch No 8
The Training loss of the epoch,  0.5280031548320562
The Training AUC of the epoch,  0.81468
The validation loss of the epoch,  0.5695825976886969
The validation AUC of the epoch, 0.79266
----------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:20<00:00,  1.18s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:24<00:00,  2.12it/s]


----------------------------------------------------
Epoch No 9
The Training loss of the epoch,  0.5170115168581064
The Training AUC of the epoch,  0.82383
The validation loss of the epoch,  0.5599908828735352
The validation AUC of the epoch, 0.79076
----------------------------------------------------
Epoch 00010: reducing learning rate of group 0 to 5.0000e-04.


100%|███████████████████████████████████████████████████████████████████████████████| 1740/1740 [34:41<00:00,  1.20s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 435/435 [03:26<00:00,  2.10it/s]


----------------------------------------------------
Epoch No 10
The Training loss of the epoch,  0.48044726853740627
The Training AUC of the epoch,  0.85136
The validation loss of the epoch,  0.5939084420258971
The validation AUC of the epoch, 0.78201
----------------------------------------------------


 74%|██████████████████████████████████████████████████████████                     | 1280/1740 [25:32<09:08,  1.19s/it]