In [None]:
from google.colab import drive
# Mount Google Drive so the project folder under /content/drive/MyDrive is reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import sys
# Work from the resnet1D folder: the next cell imports `model` and `train_eval`
# from the current directory (presumably model.py / train_eval.py live here).
os.chdir(os.path.join('/content/drive/MyDrive', 'Seed_Classification_4_varieties', 'resnet1D'))
os.getcwd()

'/content/drive/MyDrive/Seed_Classification_4_varieties/resnet1D'

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import imblearn.over_sampling as oversample
from  model import *
from  train_eval import *
import torch
from torchsummary import summary
import torch.optim as optim
import pickle
from sklearn.metrics import confusion_matrix, classification_report


In [None]:
# All data and model artefacts live under this Google Drive folder.
PROJECT_ROOT = '/content/drive/MyDrive/Seed_Classification_4_varieties'

# Per data type: directory holding the train / val / test CSV splits.
BASE_DIR = {
    'crease_up': f'{PROJECT_ROOT}/crease_up/Data',
    'crease_down': f'{PROJECT_ROOT}/crease_down/Data',
    'crease_up_down_combined': f'{PROJECT_ROOT}/crease_up_down_combined/Data',
}

# Hyper-parameters shared by every run. `update_config` overwrites the split
# paths and `input_size` per data type; the training cells add 'tr_path' and 'name'.
config = {
    'BATCH_SIZE': 64,
    'lr': 0.000008,
    'EPOCHS': 50,
    'input_size': 147,   # feature columns per sample for a single orientation
    'output_size': 4,    # number of seed varieties (classes)
    'seed': 42,          # RNG seed for reproducible runs

    # Placeholders only: update_config() replaces them with absolute paths.
    'val_path': 'Data/df_val.csv',
    'tst_path': 'Data/df_tst.csv',
}

# Existing trained weights are looked up here (same folder as the working directory).
SAVE_DIR = f'{PROJECT_ROOT}/resnet1D'

# Seed torch and numpy so weight initialisation and data shuffling repeat across sessions.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])


def update_config(data_type: str) -> None:
    """Point `config` at the validation/test splits of `data_type` and set its input width.

    Args:
        data_type: one of the keys of ``BASE_DIR``
            ('crease_up', 'crease_down', 'crease_up_down_combined').

    Raises:
        KeyError: if ``data_type`` is not a key of ``BASE_DIR``.
    """
    if data_type not in BASE_DIR:
        raise KeyError(f"unknown data_type {data_type!r}; expected one of {sorted(BASE_DIR)}")
    data_dir = BASE_DIR[data_type]
    config['val_path'] = f"{data_dir}/df_val.csv"
    config['tst_path'] = f"{data_dir}/df_tst.csv"
    # The combined set has twice the features of a single orientation (2 x 147);
    # presumably crease-up and crease-down features side by side — TODO confirm.
    config['input_size'] = 294 if data_type == 'crease_up_down_combined' else 147


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build one network with the current input width purely to inspect its layer stack.
model = ResNet1D(input_length=config['input_size'], num_classes=config['output_size']).to(device)
summary(model, input_size=(config['input_size'],))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1               [-1, 16, 74]             128
       BatchNorm1d-2               [-1, 16, 74]              32
             PReLU-3               [-1, 16, 74]               1
         MaxPool1d-4               [-1, 16, 37]               0
            Conv1d-5               [-1, 32, 19]           1,568
       BatchNorm1d-6               [-1, 32, 19]              64
             PReLU-7               [-1, 32, 19]               1
            Conv1d-8               [-1, 32, 19]           3,104
       BatchNorm1d-9               [-1, 32, 19]              64
           Conv1d-10               [-1, 32, 19]             544
      BatchNorm1d-11               [-1, 32, 19]              64
            PReLU-12               [-1, 32, 19]               1
  ResidualBlock1D-13               [-1, 32, 19]               0
           Conv1d-14               [-1,

## Plotting


In [None]:
def plot_history(history: dict, model_name):
    """Plot train/validation loss and accuracy curves and save them as a PNG.

    Loss and accuracy are on different scales, so they are drawn on two stacked
    panels that share the epoch axis instead of one unlabeled set of axes.

    Args:
        history: mapping with per-epoch sequences under the keys
            'train_loss', 'val_loss', 'train_acc' and 'val_acc'.
        model_name: prefix of the output file '<model_name>_Train_Val_Curves.png',
            written relative to the current working directory.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

    ax_loss.plot(history['train_loss'], 'o-', color='green', label='Train Loss')
    ax_loss.plot(history['val_loss'], 'o-', color='red', label='Validation Loss')
    ax_loss.set_ylabel('Loss', fontsize=14)
    ax_loss.set_title(f'{model_name}: training curves', fontsize=16)
    ax_loss.legend(fontsize=14)

    ax_acc.plot(history['train_acc'], 'o-', color='pink', label='Train Accuracy')
    ax_acc.plot(history['val_acc'], 'o-', color='blue', label='Validation Accuracy')
    ax_acc.set_ylabel('Accuracy', fontsize=14)
    ax_acc.set_xlabel('Epochs', fontsize=14)
    ax_acc.legend(fontsize=14)

    fig.savefig(f'{model_name}_Train_Val_Curves.png', bbox_inches='tight')


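## Training on the Imbalanced Data

Each data type is trained on its original, class-imbalanced training split (`df_tr_imbalanced.csv`). Runs whose `<name>_model.pt` already exists in `SAVE_DIR` are skipped.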

In [None]:
for data_type in ['crease_up', 'crease_down', 'crease_up_down_combined']:
    update_config(data_type)

    config['tr_path'] = f'{BASE_DIR[data_type]}/df_tr_imbalanced.csv'
    config['name'] = f'{data_type}_imbalanced'

    # Skip runs whose trained weights already exist, before reading any CSVs.
    if os.path.exists(f"{SAVE_DIR}/{config['name']}_model.pt"):
        continue

    df_tr = pd.read_csv(config['tr_path'])
    df_tst = pd.read_csv(config['tst_path'])

    # Feature-wise normalisation statistics from the training split only
    # (the label is the last column); validation and test reuse them.
    features = df_tr.iloc[:, :-1].to_numpy()
    mean = torch.tensor(features.mean(axis=0), dtype=torch.float32)
    std = torch.tensor(features.std(axis=0), dtype=torch.float32)
    # A constant column has std == 0; MyDataset presumably computes (x - mean) / std,
    # so fall back to 1 for such columns instead of dividing by zero.
    std[std == 0] = 1.0
    print(mean.shape, std.shape)

    # Re-seed per run so results do not depend on which earlier runs were skipped.
    torch.manual_seed(config.get('seed', 42))

    tr_dataset = MyDataset(path=config['tr_path'], mean=mean, std=std, apply_transform=True)
    tr_loader = DataLoader(tr_dataset, batch_size=config['BATCH_SIZE'], shuffle=True)

    # Shuffling adds nothing for evaluation; a fixed order keeps validation repeatable.
    # NOTE(review): confirm apply_transform does no random augmentation on validation data.
    val_dataset = MyDataset(path=config['val_path'], mean=mean, std=std, apply_transform=True)
    val_loader = DataLoader(val_dataset, batch_size=config['BATCH_SIZE'], shuffle=False)

    # Sanity check of one training batch's shapes.
    x_batch, y_batch = next(iter(tr_loader))
    print(x_batch.shape, y_batch.shape)

    model = ResNet1D(num_classes=config['output_size'], input_length=config['input_size']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    # NOTE(review): unweighted loss on imbalanced data; class weights are a possible variant.
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the SMOTE/ADASYN cells use gamma=0.99; confirm 0.975 here is intentional.
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.975)

    trainer = Train_Eval(model=model, model_name=config['name'], device=device, train_loader=tr_loader,
                         val_loader=val_loader, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
    trainer.run(config['EPOCHS'])
    trainer.eval(df_tst, mean, std)

    # Drop references to this run's objects and release cached GPU memory before the next one.
    del model, optimizer, criterion, lr_scheduler, trainer
    torch.cuda.empty_cache()


torch.Size([147]) torch.Size([147])
torch.Size([147]) torch.Size([147])
torch.Size([294]) torch.Size([294])


## SMOTE-Balanced Training

Each data type is trained on its SMOTE-oversampled training split (`df_tr_smote.csv`).

In [None]:
for data_type in ['crease_up', 'crease_down', 'crease_up_down_combined']:
    update_config(data_type)

    config['tr_path'] = f'{BASE_DIR[data_type]}/df_tr_smote.csv'
    config['name'] = f'{data_type}_smote'

    # Skip runs whose trained weights already exist, before reading any CSVs.
    if os.path.exists(f"{SAVE_DIR}/{config['name']}_model.pt"):
        continue

    df_tr = pd.read_csv(config['tr_path'])
    df_tst = pd.read_csv(config['tst_path'])

    # Feature-wise normalisation statistics from the training split only
    # (the label is the last column); validation and test reuse them.
    features = df_tr.iloc[:, :-1].to_numpy()
    mean = torch.tensor(features.mean(axis=0), dtype=torch.float32)
    std = torch.tensor(features.std(axis=0), dtype=torch.float32)
    # A constant column has std == 0; MyDataset presumably computes (x - mean) / std,
    # so fall back to 1 for such columns instead of dividing by zero.
    std[std == 0] = 1.0
    print(mean.shape, std.shape)

    # Re-seed per run so results do not depend on which earlier runs were skipped.
    torch.manual_seed(config.get('seed', 42))

    tr_dataset = MyDataset(path=config['tr_path'], mean=mean, std=std, apply_transform=True)
    tr_loader = DataLoader(tr_dataset, batch_size=config['BATCH_SIZE'], shuffle=True)

    # Shuffling adds nothing for evaluation; a fixed order keeps validation repeatable.
    # NOTE(review): confirm apply_transform does no random augmentation on validation data.
    val_dataset = MyDataset(path=config['val_path'], mean=mean, std=std, apply_transform=True)
    val_loader = DataLoader(val_dataset, batch_size=config['BATCH_SIZE'], shuffle=False)

    # Sanity check of one training batch's shapes.
    x_batch, y_batch = next(iter(tr_loader))
    print(x_batch.shape, y_batch.shape)

    model = ResNet1D(num_classes=config['output_size'], input_length=config['input_size']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    trainer = Train_Eval(model=model, model_name=config['name'], device=device, train_loader=tr_loader,
                         val_loader=val_loader, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
    trainer.run(config['EPOCHS'])
    trainer.eval(df_tst, mean, std)

    # Drop references to this run's objects and release cached GPU memory before the next one.
    del model, optimizer, criterion, lr_scheduler, trainer
    torch.cuda.empty_cache()


torch.Size([147]) torch.Size([147])
torch.Size([147]) torch.Size([147])
torch.Size([294]) torch.Size([294])
torch.Size([64, 294]) torch.Size([64])

Epoch: 0


100%|██████████| 661/661 [00:56<00:00, 11.72it/s]
100%|██████████| 64/64 [00:01<00:00, 60.84it/s]


Epoch [0] --> LossTr: 1.2364    AccTr: 0.5469    lossVal : 1.1041     accVal : 0.6611

Detected network improvement, saving current model  "✅"

Epoch: 1


100%|██████████| 661/661 [00:57<00:00, 11.56it/s]
100%|██████████| 64/64 [00:01<00:00, 61.35it/s]


Epoch [1] --> LossTr: 1.0542    AccTr: 0.7180    lossVal : 1.0068     accVal : 0.7490

Detected network improvement, saving current model  "✅"

Epoch: 2


100%|██████████| 661/661 [01:00<00:00, 10.87it/s]
100%|██████████| 64/64 [00:01<00:00, 61.21it/s]


Epoch [2] --> LossTr: 0.9700    AccTr: 0.7918    lossVal : 0.9603     accVal : 0.7920

Detected network improvement, saving current model  "✅"

Epoch: 3


100%|██████████| 661/661 [00:54<00:00, 12.05it/s]
100%|██████████| 64/64 [00:01<00:00, 59.01it/s]


Epoch [3] --> LossTr: 0.9284    AccTr: 0.8285    lossVal : 0.9275     accVal : 0.8206

Detected network improvement, saving current model  "✅"

Epoch: 4


100%|██████████| 661/661 [00:54<00:00, 12.16it/s]
100%|██████████| 64/64 [00:01<00:00, 45.48it/s]


Epoch [4] --> LossTr: 0.9015    AccTr: 0.8528    lossVal : 0.9063     accVal : 0.8440

Detected network improvement, saving current model  "✅"

Epoch: 5


100%|██████████| 661/661 [00:57<00:00, 11.49it/s]
100%|██████████| 64/64 [00:01<00:00, 60.67it/s]


Epoch [5] --> LossTr: 0.8804    AccTr: 0.8743    lossVal : 0.8922     accVal : 0.8574

Detected network improvement, saving current model  "✅"

Epoch: 6


100%|██████████| 661/661 [00:57<00:00, 11.51it/s]
100%|██████████| 64/64 [00:01<00:00, 60.22it/s]


Epoch [6] --> LossTr: 0.8672    AccTr: 0.8855    lossVal : 0.8803     accVal : 0.8694

Detected network improvement, saving current model  "✅"

Epoch: 7


100%|██████████| 661/661 [00:54<00:00, 12.03it/s]
100%|██████████| 64/64 [00:01<00:00, 41.46it/s]


Epoch [7] --> LossTr: 0.8569    AccTr: 0.8949    lossVal : 0.8721     accVal : 0.8738


Epoch: 8


100%|██████████| 661/661 [00:53<00:00, 12.37it/s]
100%|██████████| 64/64 [00:01<00:00, 59.28it/s]


Epoch [8] --> LossTr: 0.8487    AccTr: 0.9023    lossVal : 0.8648     accVal : 0.8845

Detected network improvement, saving current model  "✅"

Epoch: 9


100%|██████████| 661/661 [00:55<00:00, 11.95it/s]
100%|██████████| 64/64 [00:01<00:00, 60.68it/s]


Epoch [9] --> LossTr: 0.8416    AccTr: 0.9088    lossVal : 0.8597     accVal : 0.8870


Epoch: 10


100%|██████████| 661/661 [00:55<00:00, 11.88it/s]
100%|██████████| 64/64 [00:01<00:00, 61.10it/s]


Epoch [10] --> LossTr: 0.8370    AccTr: 0.9116    lossVal : 0.8502     accVal : 0.8967

Detected network improvement, saving current model  "✅"

Epoch: 11


100%|██████████| 661/661 [00:56<00:00, 11.70it/s]
100%|██████████| 64/64 [00:01<00:00, 37.94it/s]


Epoch [11] --> LossTr: 0.8318    AccTr: 0.9167    lossVal : 0.8465     accVal : 0.8997


Epoch: 12


100%|██████████| 661/661 [00:56<00:00, 11.79it/s]
100%|██████████| 64/64 [00:01<00:00, 59.89it/s]


Epoch [12] --> LossTr: 0.8289    AccTr: 0.9190    lossVal : 0.8445     accVal : 0.9019


Epoch: 13


100%|██████████| 661/661 [00:56<00:00, 11.70it/s]
100%|██████████| 64/64 [00:01<00:00, 60.34it/s]


Epoch [13] --> LossTr: 0.8244    AccTr: 0.9230    lossVal : 0.8399     accVal : 0.9058

Detected network improvement, saving current model  "✅"

Epoch: 14


100%|██████████| 661/661 [00:56<00:00, 11.75it/s]
100%|██████████| 64/64 [00:01<00:00, 60.03it/s]


Epoch [14] --> LossTr: 0.8207    AccTr: 0.9270    lossVal : 0.8356     accVal : 0.9099


Epoch: 15


100%|██████████| 661/661 [00:55<00:00, 11.94it/s]
100%|██████████| 64/64 [00:01<00:00, 39.48it/s]


Epoch [15] --> LossTr: 0.8178    AccTr: 0.9300    lossVal : 0.8344     accVal : 0.9109


Epoch: 16


100%|██████████| 661/661 [00:57<00:00, 11.47it/s]
100%|██████████| 64/64 [00:01<00:00, 43.11it/s]


Epoch [16] --> LossTr: 0.8162    AccTr: 0.9315    lossVal : 0.8296     accVal : 0.9165

Detected network improvement, saving current model  "✅"

Epoch: 17


100%|██████████| 661/661 [00:58<00:00, 11.35it/s]
100%|██████████| 64/64 [00:01<00:00, 59.18it/s]


Epoch [17] --> LossTr: 0.8128    AccTr: 0.9352    lossVal : 0.8294     accVal : 0.9148


Epoch: 18


100%|██████████| 661/661 [00:56<00:00, 11.73it/s]
100%|██████████| 64/64 [00:01<00:00, 59.22it/s]


Epoch [18] --> LossTr: 0.8100    AccTr: 0.9370    lossVal : 0.8238     accVal : 0.9248


Epoch: 19


100%|██████████| 661/661 [00:55<00:00, 11.99it/s]
100%|██████████| 64/64 [00:01<00:00, 37.95it/s]


Epoch [19] --> LossTr: 0.8086    AccTr: 0.9388    lossVal : 0.8241     accVal : 0.9221


Epoch: 20


100%|██████████| 661/661 [00:54<00:00, 12.02it/s]
100%|██████████| 64/64 [00:01<00:00, 59.84it/s]


Epoch [20] --> LossTr: 0.8061    AccTr: 0.9410    lossVal : 0.8226     accVal : 0.9248


Epoch: 21


100%|██████████| 661/661 [00:58<00:00, 11.30it/s]
100%|██████████| 64/64 [00:01<00:00, 57.78it/s]


Epoch [21] --> LossTr: 0.8046    AccTr: 0.9419    lossVal : 0.8191     accVal : 0.9263

Detected network improvement, saving current model  "✅"

Epoch: 22


100%|██████████| 661/661 [00:56<00:00, 11.61it/s]
100%|██████████| 64/64 [00:01<00:00, 59.53it/s]


Epoch [22] --> LossTr: 0.8025    AccTr: 0.9443    lossVal : 0.8168     accVal : 0.9282


Epoch: 23


100%|██████████| 661/661 [00:55<00:00, 11.85it/s]
100%|██████████| 64/64 [00:01<00:00, 37.31it/s]


Epoch [23] --> LossTr: 0.8006    AccTr: 0.9465    lossVal : 0.8149     accVal : 0.9319


Epoch: 24


100%|██████████| 661/661 [00:54<00:00, 12.06it/s]
100%|██████████| 64/64 [00:01<00:00, 60.11it/s]


Epoch [24] --> LossTr: 0.7986    AccTr: 0.9484    lossVal : 0.8122     accVal : 0.9341


Epoch: 25


100%|██████████| 661/661 [00:56<00:00, 11.74it/s]
100%|██████████| 64/64 [00:01<00:00, 58.16it/s]


Epoch [25] --> LossTr: 0.7986    AccTr: 0.9475    lossVal : 0.8117     accVal : 0.9324


Epoch: 26


100%|██████████| 661/661 [00:58<00:00, 11.27it/s]
100%|██████████| 64/64 [00:01<00:00, 58.94it/s]


Epoch [26] --> LossTr: 0.7967    AccTr: 0.9501    lossVal : 0.8108     accVal : 0.9351


Epoch: 27


100%|██████████| 661/661 [00:56<00:00, 11.64it/s]
100%|██████████| 64/64 [00:01<00:00, 58.96it/s]


Epoch [27] --> LossTr: 0.7947    AccTr: 0.9514    lossVal : 0.8092     accVal : 0.9375


Epoch: 28


100%|██████████| 661/661 [00:55<00:00, 12.00it/s]
100%|██████████| 64/64 [00:01<00:00, 39.74it/s]


Epoch [28] --> LossTr: 0.7938    AccTr: 0.9523    lossVal : 0.8066     accVal : 0.9409

Detected network improvement, saving current model  "✅"

Epoch: 29


100%|██████████| 661/661 [00:55<00:00, 11.99it/s]
100%|██████████| 64/64 [00:01<00:00, 59.92it/s]


Epoch [29] --> LossTr: 0.7934    AccTr: 0.9526    lossVal : 0.8048     accVal : 0.9404


Epoch: 30


100%|██████████| 661/661 [00:56<00:00, 11.75it/s]
100%|██████████| 64/64 [00:01<00:00, 59.82it/s]


Epoch [30] --> LossTr: 0.7916    AccTr: 0.9544    lossVal : 0.8053     accVal : 0.9407


Epoch: 31


100%|██████████| 661/661 [00:57<00:00, 11.45it/s]
100%|██████████| 64/64 [00:01<00:00, 58.90it/s]


Epoch [31] --> LossTr: 0.7907    AccTr: 0.9552    lossVal : 0.8041     accVal : 0.9417


Epoch: 32


100%|██████████| 661/661 [00:56<00:00, 11.75it/s]
100%|██████████| 64/64 [00:01<00:00, 59.50it/s]


Epoch [32] --> LossTr: 0.7897    AccTr: 0.9561    lossVal : 0.8029     accVal : 0.9419


Epoch: 33


100%|██████████| 661/661 [00:55<00:00, 11.84it/s]
100%|██████████| 64/64 [00:02<00:00, 30.17it/s]


Epoch [33] --> LossTr: 0.7880    AccTr: 0.9582    lossVal : 0.8007     accVal : 0.9441


Epoch: 34


100%|██████████| 661/661 [00:55<00:00, 11.82it/s]
100%|██████████| 64/64 [00:01<00:00, 58.76it/s]


Epoch [34] --> LossTr: 0.7876    AccTr: 0.9579    lossVal : 0.8012     accVal : 0.9434


Epoch: 35


100%|██████████| 661/661 [00:56<00:00, 11.67it/s]
100%|██████████| 64/64 [00:01<00:00, 58.75it/s]


Epoch [35] --> LossTr: 0.7861    AccTr: 0.9601    lossVal : 0.8009     accVal : 0.9448


Epoch: 36


100%|██████████| 661/661 [00:58<00:00, 11.23it/s]
100%|██████████| 64/64 [00:01<00:00, 59.35it/s]


Epoch [36] --> LossTr: 0.7849    AccTr: 0.9607    lossVal : 0.7991     accVal : 0.9456


Epoch: 37


100%|██████████| 661/661 [00:56<00:00, 11.63it/s]
100%|██████████| 64/64 [00:01<00:00, 59.32it/s]


Epoch [37] --> LossTr: 0.7847    AccTr: 0.9610    lossVal : 0.7973     accVal : 0.9487


Epoch: 38


100%|██████████| 661/661 [00:57<00:00, 11.58it/s]
100%|██████████| 64/64 [00:01<00:00, 59.95it/s]


Epoch [38] --> LossTr: 0.7833    AccTr: 0.9624    lossVal : 0.7967     accVal : 0.9490


Epoch: 39


100%|██████████| 661/661 [00:55<00:00, 11.97it/s]
100%|██████████| 64/64 [00:01<00:00, 38.11it/s]


Epoch [39] --> LossTr: 0.7839    AccTr: 0.9614    lossVal : 0.7980     accVal : 0.9473


Epoch: 40


100%|██████████| 661/661 [00:57<00:00, 11.46it/s]
100%|██████████| 64/64 [00:01<00:00, 39.71it/s]


Epoch [40] --> LossTr: 0.7829    AccTr: 0.9626    lossVal : 0.7954     accVal : 0.9487

Detected network improvement, saving current model  "✅"

Epoch: 41


100%|██████████| 661/661 [00:55<00:00, 11.82it/s]
100%|██████████| 64/64 [00:01<00:00, 58.75it/s]


Epoch [41] --> LossTr: 0.7822    AccTr: 0.9636    lossVal : 0.7973     accVal : 0.9458


Epoch: 42


100%|██████████| 661/661 [00:56<00:00, 11.66it/s]
100%|██████████| 64/64 [00:01<00:00, 59.64it/s]


Epoch [42] --> LossTr: 0.7811    AccTr: 0.9646    lossVal : 0.7942     accVal : 0.9517


Epoch: 43


100%|██████████| 661/661 [00:56<00:00, 11.74it/s]
100%|██████████| 64/64 [00:01<00:00, 59.36it/s]


Epoch [43] --> LossTr: 0.7795    AccTr: 0.9661    lossVal : 0.7929     accVal : 0.9521


Epoch: 44


100%|██████████| 661/661 [00:56<00:00, 11.75it/s]
100%|██████████| 64/64 [00:01<00:00, 58.07it/s]


Epoch [44] --> LossTr: 0.7791    AccTr: 0.9667    lossVal : 0.7921     accVal : 0.9531


Epoch: 45


100%|██████████| 661/661 [00:57<00:00, 11.59it/s]
100%|██████████| 64/64 [00:01<00:00, 38.86it/s]


Epoch [45] --> LossTr: 0.7788    AccTr: 0.9666    lossVal : 0.7922     accVal : 0.9512


Epoch: 46


100%|██████████| 661/661 [00:55<00:00, 11.81it/s]
100%|██████████| 64/64 [00:01<00:00, 59.37it/s]


Epoch [46] --> LossTr: 0.7784    AccTr: 0.9665    lossVal : 0.7933     accVal : 0.9495


Epoch: 47


100%|██████████| 661/661 [00:56<00:00, 11.67it/s]
100%|██████████| 64/64 [00:01<00:00, 59.42it/s]


Epoch [47] --> LossTr: 0.7770    AccTr: 0.9684    lossVal : 0.7914     accVal : 0.9556


Epoch: 48


100%|██████████| 661/661 [00:56<00:00, 11.68it/s]
100%|██████████| 64/64 [00:01<00:00, 59.37it/s]


Epoch [48] --> LossTr: 0.7771    AccTr: 0.9684    lossVal : 0.7924     accVal : 0.9504


Epoch: 49


100%|██████████| 661/661 [00:56<00:00, 11.80it/s]
100%|██████████| 64/64 [00:01<00:00, 37.55it/s]


Epoch [49] --> LossTr: 0.7764    AccTr: 0.9686    lossVal : 0.7903     accVal : 0.9546

acc : 0.951733   
{0: 3030, 1: 3030, 2: 1414, 3: 606}


## ADASYN-Balanced Training

Each data type is trained on its ADASYN-oversampled training split (`df_tr_adasyn.csv`).

In [None]:
for data_type in ['crease_up', 'crease_down', 'crease_up_down_combined']:
    update_config(data_type)

    config['tr_path'] = f'{BASE_DIR[data_type]}/df_tr_adasyn.csv'
    config['name'] = f'{data_type}_adasyn'

    # Skip runs whose trained weights already exist, before reading any CSVs.
    if os.path.exists(f"{SAVE_DIR}/{config['name']}_model.pt"):
        continue

    df_tr = pd.read_csv(config['tr_path'])
    df_tst = pd.read_csv(config['tst_path'])

    # Feature-wise normalisation statistics from the training split only
    # (the label is the last column); validation and test reuse them.
    features = df_tr.iloc[:, :-1].to_numpy()
    mean = torch.tensor(features.mean(axis=0), dtype=torch.float32)
    std = torch.tensor(features.std(axis=0), dtype=torch.float32)
    # A constant column has std == 0; MyDataset presumably computes (x - mean) / std,
    # so fall back to 1 for such columns instead of dividing by zero.
    std[std == 0] = 1.0
    print(mean.shape, std.shape)

    # Re-seed per run so results do not depend on which earlier runs were skipped.
    torch.manual_seed(config.get('seed', 42))

    tr_dataset = MyDataset(path=config['tr_path'], mean=mean, std=std, apply_transform=True)
    tr_loader = DataLoader(tr_dataset, batch_size=config['BATCH_SIZE'], shuffle=True)

    # Shuffling adds nothing for evaluation; a fixed order keeps validation repeatable.
    # NOTE(review): confirm apply_transform does no random augmentation on validation data.
    val_dataset = MyDataset(path=config['val_path'], mean=mean, std=std, apply_transform=True)
    val_loader = DataLoader(val_dataset, batch_size=config['BATCH_SIZE'], shuffle=False)

    # Sanity check of one training batch's shapes.
    x_batch, y_batch = next(iter(tr_loader))
    print(x_batch.shape, y_batch.shape)

    model = ResNet1D(num_classes=config['output_size'], input_length=config['input_size']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    trainer = Train_Eval(model=model, model_name=config['name'], device=device, train_loader=tr_loader,
                         val_loader=val_loader, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
    trainer.run(config['EPOCHS'])
    trainer.eval(df_tst, mean, std)

    # Drop references to this run's objects and release cached GPU memory before the next one.
    del model, optimizer, criterion, lr_scheduler, trainer
    torch.cuda.empty_cache()


torch.Size([147]) torch.Size([147])
torch.Size([147]) torch.Size([147])
torch.Size([294]) torch.Size([294])
