In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import random
import math
import time
import torch.nn.functional as F

from tqdm.notebook import tqdm
from torch import optim
from torch.utils.data import DataLoader, Dataset,  WeightedRandomSampler
from  scipy import stats
import scipy
import numpy as np
np.random.seed(1)
torch.manual_seed(1)
random.seed(1)

In [3]:
import os
os.chdir("/content/drive/MyDrive/MADE/Project/deap")

In [4]:
def get_padding(in_size, kernel_size, stride):
    if (in_size % stride == 0):
        padding = max(kernel_size - stride, 0)
    else:
        padding = max(kernel_size - (in_size % stride), 0)
    return (padding)    

In [5]:
def get_temporal_feature_extractor(input_size):
  pad = get_padding(input_size, 5,  2)
  conv1 = nn.Conv2d(1, 32, kernel_size = (1, 5), stride=(1, 2), padding=(0, pad))
  relu1 = nn.LeakyReLU()
  pad = get_padding(input_size, 3,  2)
  conv2 = nn.Conv2d(32, 32, kernel_size = (1, 3), stride=(1, 2), padding=(0, pad))
  relu2 = nn.LeakyReLU()
  conv3 = nn.Conv2d(32, 32, kernel_size = (1, 3), stride=(1, 2), padding=(0, pad))
  relu3 = nn.LeakyReLU()
  conv4 = nn.Conv2d(32, 32, kernel_size = (1, 16), stride=(1, 16), padding=0)
  relu4 = nn.LeakyReLU()
  #print("11")
  result = torch.nn.Sequential(conv1, relu1, conv2, relu2, conv3, relu3, conv4, relu4)
  #print("22")
  #print(result)
  return(result)

In [6]:
import glob
import pickle
from sklearn.preprocessing import MinMaxScaler

data = []
labels = []
data_dir = './data_preprocessed_python'
files = glob.glob(os.path.join(data_dir, "*.dat"))
data_raw = []
for file_data in files:
    raw_data = pickle.load(open(file_data, 'rb'), encoding='latin1')
    data.append(raw_data['data'])
    #data_raw.append(raw_data['data'][:, :, :])
    # data_raw[-1][0, :32, 0]
    # print(data_raw[-1][:, :31, :].min())
    # print(data_raw[-1][:, :31, :].max())
    # scaler = MinMaxScaler()
    # for i in range(40):
    #     #scaler.fit(data[-1][i])
    #     scaler = MinMaxScaler()
    #     data[-1][i] = scaler.fit_transform(data[-1][i])
    #     print(data_raw[-1][i])
    #     print(data[-1][i])
    #     break
    labels.append(raw_data['labels'])
    #break

In [7]:
class EmotionNet(torch.nn.Module): 
   def __init__(self, hcanals, wcanals, nfeatures, ntimes_in_sample):
      super().__init__()
      #print("1")
      self.tfe = get_temporal_feature_extractor(ntimes_in_sample) #(bs, 1, h = 9, w = 9, s = 128) -> (bs, h = 9, w = 9, s = 1)
      self.flat = nn.Flatten(1, 2)
      self.input_linear_size = 32 * 32#int(hcanals * (wcanals//2)* nfeatures * 2 + hcanals * (wcanals)* nfeatures)
      self.fc1 = nn.Linear(self.input_linear_size, 20)
      self.relu1 = nn.LeakyReLU()
      self.drop = nn.Dropout(0.3)
      self.fc2 = nn.Linear(20, 1)
   def forward(self, input):
      input = input.unsqueeze(1)
      #print(f"input_shape = {input.shape}")
      #input (bs, in_canals = 1,  h=9, w=9, s=128)
      output_tfe = self.tfe(input)
      #print(f"output_tfe.shape = {output_tfe.shape}")
      #output_tfe (bs, in_canals = 32,  h=9, w=9, s=1)
      output_tfe = output_tfe.squeeze(3)
      #print(f"output_tfe.shape = {output_tfe.shape}")
      
      output_tfe = output_tfe.permute(0, 2, 1) 
      output_tfe_flatten = self.flat(output_tfe)
      #print(f"output_tfe.shape = {output_tfe.shape}")
      output1 = self.fc1(output_tfe_flatten)
      #print(f"output1.shape = {output1.shape}")
      output1_relu = self.relu1(output1)
      #print(f"output1_relu.shape = {output1_relu.shape}")
      output2 = self.fc2(output1_relu)
      #print(f"output2.shape = {output2.shape}")
      return output2


In [8]:
LEN_RECORD_IN_SECONDS = 60
NVIDEOS = 40
HCANALS = 9
WCANALS = 9
NTIMES_IN_SAMPLE = 128
NTIMES_IN_SEC = 128
NCANALS = 32
NFEATURES = 32
electrode_matrix = {}
electrode_matrix['FP1'] = [0, 3]
electrode_matrix['FP2'] = [0, 5]
electrode_matrix['AF3'] = [1, 3]
electrode_matrix['AF4'] = [1, 5]
electrode_matrix['F7']  = [2, 0]
electrode_matrix['F3']  = [2, 2]
electrode_matrix['FZ']  = [2, 4]
electrode_matrix['F4']  = [2, 6]
electrode_matrix['F8']  = [2, 8]
electrode_matrix['FC5']  = [3, 1]
electrode_matrix['FC1']  = [3, 3]
electrode_matrix['FC2']  = [3, 5]
electrode_matrix['FC6']  = [3, 7]
electrode_matrix['T7']  = [4, 0]
electrode_matrix['C3']  = [4, 2]
electrode_matrix['CZ']  = [4, 4]
electrode_matrix['C4']  = [4, 6]
electrode_matrix['T8']  = [4, 8]
electrode_matrix['CP5']  = [5, 1]
electrode_matrix['CP1']  = [5, 3]
electrode_matrix['CP2']  = [5, 5]
electrode_matrix['CP6']  = [5, 7]
electrode_matrix['P7']  = [6, 0]
electrode_matrix['P3']  = [6, 2]
electrode_matrix['PZ']  = [6, 4]
electrode_matrix['P4']  = [6, 6]
electrode_matrix['P8']  = [6, 8]
electrode_matrix['PO3'] = [7, 3]
electrode_matrix['PO4'] = [7, 5]
electrode_matrix['O1'] = [8, 3]
electrode_matrix['OZ'] = [8, 4]
electrode_matrix['O2'] = [8, 5]

list_electrodes = ['FP1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3',	'T7',	'CP5',	'CP1',	'P3',	'P7',	'PO3',	'O1',	'OZ',	'PZ',	'FP2',	'AF4', 'FZ', 'F4', 'F8', 'FC6',	'FC2',	'CZ', 'C4', 'T8', 'CP6',	'CP2',	'P4', 	'P8',	'PO4',	'O2']
data_dir = './data_preprocessed_python'
TRAIN_SIZE = 0.9
THRESHOLD = 5

In [9]:
import glob
import pickle
from collections import Counter

class EmotionDataset(Dataset):
    def __init__ (self, data_dir, type, ind, data, labels):
       self.data = []
       self.labels = []
       self.cnt = [Counter(), Counter(), Counter(),Counter()]
       #data_dir = './data_preprocessed_python'
       #files = glob.glob(os.path.join(data_dir, "*.dat"))[0:1]
       self.type = type
       #split = int(LEN_RECORD_IN_SECONDS)# *  TRAIN_SIZE)
       self.ind = ind
       self.len_files = []
       for s in range(len(data)):
            #print(file_data)
            #raw_data = pickle.load(open(file_data, 'rb'), encoding='latin1')
            #print(raw_data['data'].shape)
            #labels = raw_data['labels']
            self.data.append(data[s][ind, :, 3 * NTIMES_IN_SEC :LEN_RECORD_IN_SECONDS * NTIMES_IN_SEC + 3 * NTIMES_IN_SEC])
            self.len_files.append(len(ind) *  LEN_RECORD_IN_SECONDS - 1)
            self.len_record = LEN_RECORD_IN_SECONDS
            labels_bin_sub = (labels[s] >= THRESHOLD)
            self.labels.append(labels[s][ind])          
            for i in range(4):
              self.cnt[i].update(list(labels_bin_sub[:, i]))    

            
       self.len_cumsum = np.cumsum(self.len_files)     
       print(self.data[0].shape)
       print(self.labels[0].shape)



    def __len__(self):
        result =  sum(self.len_files) - 10
        return result

    def get_index_record(self, item):
      for i_file in range(len(self.len_cumsum)):
         #print(item, self.len_cumsum[i_file])
         if (item > self.len_cumsum[i_file]):
            continue
         else:
            break
      if i_file == 0:
         index_in_file = item
      else:
         index_in_file = item  - self.len_cumsum[i_file - 1]
      nvideo = index_in_file//(self.len_record)# * LEN_RECORD_IN_SECONDS *  NTIMES_IN_SEC)
      nsec = (index_in_file - nvideo * self.len_record) # *   NTIMES_IN_SEC)

      return i_file, index_in_file, nvideo, nsec

   
    def __getitem__(self, item):
      sample = {}
      #print(item)
      i_file, index_in_file, nvideo, nsec = self.get_index_record(item)
      #print(i_file, index_in_file, nvideo, nsec )
      sample['data'] = np.zeros((32, NTIMES_IN_SAMPLE))

      sample_from_one_canals = []
      for i_canal in range(NCANALS):
        sample_from_one_canal = self.data[i_file][nvideo, i_canal, nsec * 128 : nsec * 128 + 128]
        sample_from_one_canals.append(sample_from_one_canal)
        #print(sample_from_one_canal.shape)
      sample_from_one_canals = np.asarray(sample_from_one_canals).copy()
      sample_from_one_canals = scipy.stats.zscore(sample_from_one_canals, axis = 0)

      for i_canal in range(NCANALS):
        #sample_from_one_canal = torch.FloatTensor(self.data[i_file][nvideo, i_canal, nsec * 128 : nsec * 128 + 128])
        #print(sample_from_one_canal.shape)
        #sample['data'][electrode_matrix[list_electrodes[i_canal]][0],  electrode_matrix[list_electrodes[i_canal]][1]] = sample_from_one_canal
        sample['data'][i_canal] = sample_from_one_canals[i_canal]
      #if self.type == 'train' :
      sample['data'] = torch.FloatTensor(sample['data'])
      sample['labels']  = torch.FloatTensor(self.labels[i_file][nvideo])
          #print(nvideo)
      #else:
      #    sample['labels']  = torch.LongTensor(self.labels[i_file][int(NVIDEOS * TRAIN_SIZE) + nvideo])
          #print(int(NVIDEOS * TRAIN_SIZE) + nvideo)    
      #print(sample)
      return sample


In [23]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


In [15]:
class Args:
  def __init__(self): #(data_path, epoch, batch_siz, image_size, learning_rate, weight_deca, learning_rate, learning_rate_gamma, weight_bce, load, output_dir)
    self.data_path = "/content/drive/MyDrive/MADE/semester2/CV/contest02/data/"
    self.epochs = 2
    self.batch_size = 100
    self.lr= 3e-4
    self.weight_decay= 1e-6
    self.learning_rate=None
    self.learning_rate_gamma=None
    self.weight_bce=1
    self.load=None
    self.output_dir="runs/segmentation_baseline"
    self.data_dir ="./data_preprocessed_python/"# "/content/drive/MyDrive/MADE/Project/train/physionet.org/"
args = Args()    

In [12]:
type_emotion = 0

In [13]:
from sklearn.model_selection import StratifiedKFold 
from sklearn.metrics import f1_score, accuracy_score
k  = 5
labels_bin = []
for i in range(32):
  temp = labels[i] > 4.5
  #print(labels[i])
  #print(temp)
  labels_bin.append(temp)
X = np.arange(40)
y = np.array(labels_bin[0][:, type_emotion])
skf = StratifiedKFold(n_splits=k, random_state=None, shuffle=True)
balanced_split = skf.split(X, y)
for ind_train, ind_test in  balanced_split:
    print(ind_train, ind_test)
    print(sum(labels_bin[0][ind_train, type_emotion]))
    print(sum(labels_bin[0][ind_test, type_emotion]))
    break

[ 1  2  3  5  6  7  8  9 10 12 13 14 15 16 18 20 21 22 23 24 25 26 28 30
 31 32 33 35 36 37 38 39] [ 0  4 11 17 19 27 29 34]
16
4


In [17]:
train_dataset = EmotionDataset(args.data_dir, 'train', ind_train, data, labels)

class_weights_all = [1/train_dataset.cnt[0][i] for i in range(2)]
weights_samples =  [0] * train_dataset.__len__()
for i in range(train_dataset.__len__()):
    i_file, index_in_file, nvideo, nsec = train_dataset.get_index_record(i)
    #print(train_dataset.labels[i_file][nvideo])
    weights_samples[i] = class_weights_all[int(train_dataset.labels[i_file][nvideo, 0] > 4.5)]

weighted_sampler = WeightedRandomSampler(
    weights=weights_samples,
    num_samples=len(weights_samples),
    replacement=True
)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=1,
                              pin_memory=True, shuffle=True, drop_last=True)

val_dataset = EmotionDataset(args.data_dir, 'val', ind_test, data, labels)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=1,
                              pin_memory=True, shuffle=False, drop_last=False)

(32, 40, 7680)
(32, 4)
(8, 40, 7680)
(8, 4)


In [18]:
# train_dataset = EmotionDataset(files[ind_train])
# train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=1,
#                               pin_memory=True, shuffle=True, drop_last=True)


# val_dataset = EmotionDataset(files[ind_val])
# val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=1,
#                               pin_memory=True, shuffle=False, drop_last=False)

In [19]:
criterion = torch.nn.MSELoss()
#optimizer = optim.SGD(model.parameters(), lr=3e-5, momentum = 0.9)#, weight_decay=args.weight_decay)
optimizer = optim.Adam(model.parameters(), lr=3e-4)#, momentum = 0.9)#, weight_decay=args.weight_decay)


In [20]:
print(train_dataset.cnt)
print(val_dataset.cnt)
# print(files[ind_train])
# print(files[ind_val])

[Counter({1.0: 40, 5.04: 32, 9.0: 31, 5.0: 16, 7.08: 14, 5.01: 14, 2.99: 12, 7.09: 12, 7.01: 12, 6.05: 11, 4.99: 11, 4.94: 10, 4.05: 10, 7.0: 10, 4.0: 10, 7.05: 10, 4.97: 10, 7.04: 9, 4.06: 9, 6.09: 9, 4.04: 8, 8.03: 8, 7.03: 8, 4.01: 8, 5.05: 8, 5.03: 8, 8.15: 7, 3.05: 7, 3.97: 7, 6.06: 7, 3.03: 6, 4.96: 6, 6.0: 6, 4.95: 6, 8.09: 6, 5.96: 6, 6.03: 6, 8.05: 6, 5.08: 6, 3.01: 6, 2.92: 6, 2.97: 6, 7.06: 6, 3.0: 6, 1.96: 6, 1.99: 5, 8.01: 5, 4.12: 5, 6.96: 5, 2.01: 5, 6.99: 5, 5.06: 5, 7.1: 5, 7.99: 5, 5.99: 5, 3.99: 5, 6.08: 5, 8.04: 5, 8.1: 4, 2.28: 4, 7.33: 4, 7.15: 4, 6.72: 4, 4.44: 4, 1.86: 4, 8.24: 4, 6.97: 4, 8.0: 4, 7.12: 4, 1.97: 4, 4.15: 4, 3.08: 4, 4.08: 4, 4.17: 4, 8.06: 4, 8.27: 3, 1.95: 3, 4.18: 3, 7.55: 3, 6.79: 3, 1.01: 3, 2.17: 3, 1.92: 3, 6.88: 3, 7.96: 3, 1.91: 3, 4.27: 3, 2.0: 3, 6.82: 3, 3.72: 3, 4.46: 3, 6.15: 3, 4.03: 3, 8.99: 3, 3.91: 3, 6.65: 3, 3.18: 3, 2.82: 3, 7.97: 3, 3.96: 3, 3.04: 3, 6.04: 3, 5.09: 3, 7.44: 2, 7.32: 2, 3.17: 2, 6.81: 2, 7.17: 2, 3.33: 2, 3.2

In [32]:
def train(model, loader, criterion, optimizer, device, val_dataloader, batch = None):
    model.train()
    train_loss = []
    inputs = []
   
    #lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)#, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
    for s, batch in enumerate(tqdm(loader, total=len(loader), desc="training...", position=0 , leave = True)):
            model.train()
            optimizer.zero_grad()
            src  = batch['data'].to(device)
            #print(src)
            trg = batch['labels'][:, 0]
            #print(trg)
          
            levels_pred = model(src)  # B x (2 * NUM_PTS)
            #print(levels_pred.shape)
            levels_pred = levels_pred.squeeze(1).cpu()

            loss = criterion(levels_pred, trg) 

           
           
            #print(levels_pred)
            #print(trg)
            #print((trg > 4.5))
            #print(levels_pred.retain_grad)
            #print(levels_pred)
            train_loss.append(loss.item())
            loss.backward()
            #print(levels_pred.grad)
            optimizer.step()
            # if (s % 100 == 0):
            #     calculate_predictions(model, val_dataloader)
            #     calculate_predictions(model, loader)
            #break
    return np.mean(train_loss)#, mid_outputs


In [30]:
def evaluate(model, loader, criterion, device):
    
    model.eval()
    epoch_loss = 0
    history = []
  
    with torch.no_grad():
    
        for s, batch in enumerate(tqdm(loader, total=len(loader), desc="validating...", position=0 , leave = True)):
            src  = batch['data'].to(device)
            #print(src.shape)
            trg = batch['labels'][:, 0]



            levels_pred = model(src)  # B x (2 * NUM_PTS)
            #print(levels_pred.shape)
            levels_pred = levels_pred.squeeze(1).cpu()

            loss = criterion(levels_pred, trg) 

        

            #trg1 = trg[:, 1:].reshape(-1)
            #output = levels_pred[:, 1:].reshape(-1, levels_pred.shape[-1])
            #print(trg1.shape)
            #print(output.shape)
            #loss = 0
            #print(trg1.shape)
            #print(trg1)
            ##for i in range(OUTPUT_DIM):
              ##  output_class = output[trg1 == i]
              ##  trg_class = trg1[trg1 == i]
                #print(trg_class.shape)
               # if (trg_class.shape[0] != 0):
                    #print(cnt[i], i)
                #    loss += criterion(output_class, trg_class)/trg_class.shape[0]

            epoch_loss += loss.item() 
         
        
    return epoch_loss / s

In [46]:
from sklearn.metrics import accuracy_score, confusion_matrix,classification_report

def calculate_predictions(model, loader):
    model.eval()
    epoch_loss = 0
    history = []
    real = []
    pred = []
    with torch.no_grad():

        for i, batch in enumerate(tqdm(loader, total=len(loader), desc="predicting...", position=0 , leave = True)):
            src  = batch['data'].to(device)
            #print(src.shape)
            trg = batch['labels'][:, 0]
           

            levels_pred = model(src)  # B x (2 * NUM_PTS)
            levels_pred = levels_pred.squeeze(1).cpu()

            #print(levels_pred.shape)
            trg_pred = (levels_pred > 4.5)
            trg = (trg > 4.5)
            #print(trg_pred)
            #print(trg)
            real.extend(trg)
            pred.extend(trg_pred) 

            
        print(accuracy_score(real, pred)) 
        print(confusion_matrix(real, pred))  
        print(classification_report(real, pred))   
        #plt.hist(real)
        return (f1_score(real, pred)) , (accuracy_score(real, pred)) 

In [20]:
def get_model():
  model = EmotionNet(HCANALS, WCANALS, NFEATURES, NTIMES_IN_SAMPLE).to(device)
  return model


In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = get_model()

In [24]:
            
model.apply(init_weights)

EmotionNet(
  (tfe): Sequential(
    (0): Conv2d(1, 32, kernel_size=(1, 5), stride=(1, 2), padding=(0, 3))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 16))
    (7): LeakyReLU(negative_slope=0.01)
  )
  (flat): Flatten(start_dim=1, end_dim=2)
  (fc1): Linear(in_features=1024, out_features=20, bias=True)
  (relu1): LeakyReLU(negative_slope=0.01)
  (drop): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=20, out_features=1, bias=True)
)

In [41]:
criterion = torch.nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=3e-6, momentum = 0.9)#, weight_decay=args.weight_decay)
#optimizer = optim.Adam(model.parameters(), lr=3e-5)#, momentum = 0.9)#, weight_decay=args.weight_decay)

In [42]:
args.epochs = 100
#criterion =  fnn.mse_loss
train_loss_min = 10000
val_loss_min = 10000
#batch = next(iter(train_dataloader))
for epoch in range(args.epochs):
    #logger.info(f"Starting epoch {epoch + 1}/{args.epochs}.")
    
    train_loss = train(model, train_dataloader, criterion, optimizer ,device, val_dataloader)
    #if epoch % 500 == 0:
    print(train_loss)

    if (train_loss < train_loss_min):
        train_loss_min      = train_loss
        torch.save({
                         'model_state_dict': model.state_dict(),
                         'optimizer_state_dict': optimizer.state_dict(),
                       },
                       os.path.join("/content/drive/MyDrive/MADE/Project/RACNN_models_CNN/", "train.tgz")
            )  

    val_loss = evaluate(model, val_dataloader, criterion, device)
    # #break
    print(val_loss)

    # #calculate_predictions(model, val_dataloader)
    if (val_loss < val_loss_min):
         val_loss_min      = val_loss
         torch.save({'model_state_dict': model.state_dict(),    'optimizer_state_dict': optimizer.state_dict(),}, os.path.join("/content/drive/MyDrive/MADE/Project/RACNN_models_CNN/", f"val.tgz"))

training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.35932922985581


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.373347327502724


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.3602296193700045


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.430725706996871


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.35920175292558


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.460133265679568


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.359903858691989


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.401758327982784


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.3596568072600625


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.438837840568786


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.361060560235775


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.389079947828078


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.359389155373877


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.429263761084454


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.360649999553189


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.471410882628821


training...:   0%|          | 0/613 [00:00<?, ?it/s]

4.359567176265001


validating...:   0%|          | 0/154 [00:00<?, ?it/s]

4.431627965148757


training...:   0%|          | 0/613 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [44]:
model_state  = torch.load(os.path.join("/content/drive/MyDrive/MADE/Project/RACNN_models_CNN/", f"val.tgz"))
#model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device)
model.load_state_dict(model_state['model_state_dict'])
calculate_predictions(model, val_dataloader)

predicting...:   0%|          | 0/154 [00:00<?, ?it/s]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
       

In [43]:
calculate_predictions(model, val_dataloader)

predicting...:   0%|          | 0/154 [00:00<?, ?it/s]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
       

In [None]:
val_loss = evaluate(model, val_dataloader, criterion, device)
    # #break
print(val_loss)

In [None]:
print(model.tfe[0].weight)

In [None]:
print(model.tfe[0].weight)

In [None]:
#print(model.parameters)
#print(model.requires_grad_)
#print(model.tfe.parameters)
params = (model.parameters())
print(model.tfe[0].weight)

In [None]:
print(model.fc2.weight)

In [None]:
print(model.fc2.weight)

In [None]:
print(model.fc2.weight)

In [None]:
for param in model.parameters():
  print(param.requires_grad)
  print(param.grad)

In [47]:
accs_all = []
f1_all = []
for sub in range(32):
 
  #for fold in range(len(ind_tests)):
      # model_state  = torch.load(os.path.join("/content/drive/MyDrive/MADE/Project/RACNN_models/", f"val_{type_emotion}_{fold}.tgz"))
      # #model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device)
      # model.load_state_dict(model_state['model_state_dict'])
  val_dataset_sub = EmotionDataset(args.data_dir, 'val', ind_test, data[sub:sub + 1], labels[sub : sub + 1])
  val_dataloader_sub = DataLoader(val_dataset_sub, batch_size=args.batch_size, num_workers=1,
                              pin_memory=True, shuffle=False, drop_last=False)
    
  #print(val_datar_sub.cnt)
  f1, acc = calculate_predictions(model, val_dataloader_sub)
  accs_all.append(acc)
  f1_all.append(f1)
  # print(np.mean(accs_all))    
  # print(np.mean(f1_all))    
  # accs_result.append(np.mean(accs_all))
  # f1_result.append(np.mean(f1_all))
print(np.mean(accs_all))    
print(np.mean(f1_all))   

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.511727078891258
[[  0 229]
 [  0 240]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       229
        True       0.51      1.00      0.68       240

    accuracy                           0.51       469
   macro avg       0.26      0.50      0.34       469
weighted avg       0.26      0.51      0.35       469

(8, 40, 7680)
(8, 4)


  _warn_prf(average, modifier, msg_start, len(result))


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.5159914712153518
[[203  26]
 [201  39]]
              precision    recall  f1-score   support

       False       0.50      0.89      0.64       229
        True       0.60      0.16      0.26       240

    accuracy                           0.52       469
   macro avg       0.55      0.52      0.45       469
weighted avg       0.55      0.52      0.44       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6162046908315565
[[  0 180]
 [  0 289]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       180
        True       0.62      1.00      0.76       289

    accuracy                           0.62       469
   macro avg       0.31      0.50      0.38       469
weighted avg       0.38      0.62      0.47       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.8955223880597015
[[  0  49]
 [  0 420]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00        49
        True       0.90      1.00      0.94       420

    accuracy                           0.90       469
   macro avg       0.45      0.50      0.47       469
weighted avg       0.80      0.90      0.85       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.744136460554371
[[  0 120]
 [  0 349]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       120
        True       0.74      1.00      0.85       349

    accuracy                           0.74       469
   macro avg       0.37      0.50      0.43       469
weighted avg       0.55      0.74      0.63       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6417910447761194
[[  1 168]
 [  0 300]]
              precision    recall  f1-score   support

       False       1.00      0.01      0.01       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.82      0.50      0.40       469
weighted avg       0.77      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6353944562899787
[[  1 168]
 [  3 297]]
              precision    recall  f1-score   support

       False       0.25      0.01      0.01       169
        True       0.64      0.99      0.78       300

    accuracy                           0.64       469
   macro avg       0.44      0.50      0.39       469
weighted avg       0.50      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.511727078891258
[[  0 229]
 [  0 240]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       229
        True       0.51      1.00      0.68       240

    accuracy                           0.51       469
   macro avg       0.26      0.50      0.34       469
weighted avg       0.26      0.51      0.35       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.746268656716418
[[  1 119]
 [  0 349]]
              precision    recall  f1-score   support

       False       1.00      0.01      0.02       120
        True       0.75      1.00      0.85       349

    accuracy                           0.75       469
   macro avg       0.87      0.50      0.44       469
weighted avg       0.81      0.75      0.64       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.3752665245202559
[[123 106]
 [187  53]]
              precision    recall  f1-score   support

       False       0.40      0.54      0.46       229
        True       0.33      0.22      0.27       240

    accuracy                           0.38       469
   macro avg       0.37      0.38      0.36       469
weighted avg       0.36      0.38      0.36       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.8720682302771855
[[  0  60]
 [  0 409]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00        60
        True       0.87      1.00      0.93       409

    accuracy                           0.87       469
   macro avg       0.44      0.50      0.47       469
weighted avg       0.76      0.87      0.81       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.7164179104477612
[[  2 107]
 [ 26 334]]
              precision    recall  f1-score   support

       False       0.07      0.02      0.03       109
        True       0.76      0.93      0.83       360

    accuracy                           0.72       469
   macro avg       0.41      0.47      0.43       469
weighted avg       0.60      0.72      0.65       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6162046908315565
[[  0 180]
 [  0 289]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       180
        True       0.62      1.00      0.76       289

    accuracy                           0.62       469
   macro avg       0.31      0.50      0.38       469
weighted avg       0.38      0.62      0.47       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.744136460554371
[[  0 120]
 [  0 349]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       120
        True       0.74      1.00      0.85       349

    accuracy                           0.74       469
   macro avg       0.37      0.50      0.43       469
weighted avg       0.55      0.74      0.63       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.5266524520255863
[[ 42 187]
 [ 35 205]]
              precision    recall  f1-score   support

       False       0.55      0.18      0.27       229
        True       0.52      0.85      0.65       240

    accuracy                           0.53       469
   macro avg       0.53      0.52      0.46       469
weighted avg       0.53      0.53      0.47       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.6396588486140725
[[  0 169]
 [  0 300]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       169
        True       0.64      1.00      0.78       300

    accuracy                           0.64       469
   macro avg       0.32      0.50      0.39       469
weighted avg       0.41      0.64      0.50       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.8720682302771855
[[  0  60]
 [  0 409]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00        60
        True       0.87      1.00      0.93       409

    accuracy                           0.87       469
   macro avg       0.44      0.50      0.47       469
weighted avg       0.76      0.87      0.81       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

(8, 40, 7680)
(8, 4)


predicting...:   0%|          | 0/5 [00:00<?, ?it/s]

0.767590618336887
[[  0 109]
 [  0 360]]
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       109
        True       0.77      1.00      0.87       360

    accuracy                           0.77       469
   macro avg       0.38      0.50      0.43       469
weighted avg       0.59      0.77      0.67       469

0.6812366737739872
0.7812517810792325
