In [1]:
import os
import sys
from time import gmtime, strftime

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import numpy as np
import pandas as pd
import torch
import poutyne
from poutyne import Model,Experiment

In [3]:
from data.custom_data import filepath_dataframe
from data.selection import Selection,SelectionSet_1
from data.transformation import Transform_CnnLstmS,Transform_CnnS,Transform_Cnn
from data.torchData import DataLoadings,DataLoading

from data.custom_data import nucPaired_fpDataframe
from data.torchData import PairDataLoading,DataLoading
from training.contrastive_pretraining import Contrastive_PreTraining
from training.finetuning import FineTuneCNN

import models

In [4]:
#####################################################################################################################

# random seed
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# gpu setting
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(DEVICE)
device = DEVICE

## data directory
data_dir  = 'E:\\external_data\\opera_csi\\Session_2\\experiment_data\\experiment_data\\exp_7_amp_spec_only\\npy_format'
readtype = 'npy'
splitchar = '\\'
fpath = '.'

# data selection
data_selection = SelectionSet_1() # Selection(split='loov',test_sub='Three',val_sub=0.1, nuc='NUC1',room=1,sample_per_class=None)
dataselection_name = 'SelectionSet1'

# data loading
transform = None
batch_size = 64
num_workers = 0

# training
optimizer_builder = torch.optim.Adam
lr = 0.0005
pretrain_epochs = 1
finetune_epochs = 300

# model
builder = lambda : (models.baseline.ConvNet1D(strides=[16,8,4]), 6144) # strides=[16,8,4]
hidden_layer = 128
network_name = 'ConvNet1D'

# Experiment Name
comment = '3LayerDefault'


# auto
exp_name = f'{network_name}_Supervised_{dataselection_name}_Comment-{comment}'
record_dir = os.path.join(fpath,'records')
model_dir = os.path.join(fpath,'saved_model')
model_fname = os.path.join(model_dir,f'Encoder___{exp_name}')
record_fname =  os.path.join(record_dir,f'{exp_name}.csv')
print('Experiment Name: ',exp_name)
print('Cuda Availability: ',torch.cuda.is_available())

Experiment Name:  ConvNet1D_Supervised_SelectionSet1_Comment-3LayerDefault
Cuda Availability:  True


In [5]:
# -----------------------------------Main-------------------------------------------

# if __name__ == '__main__':

# data preparation
df = filepath_dataframe(data_dir,splitchar)
df_train,df_val,df_test = data_selection(df)
df_train = pd.concat([df_train,df_val])

In [6]:
##### FINE-TUNING #####

# data loading
data_loading = DataLoading(transform=transform,batch_size=batch_size,readtype=readtype,
                           num_workers=num_workers,drop_last=True)
test_loading = DataLoading(transform=transform,batch_size=len(df_test),readtype=readtype,
                           num_workers=num_workers,drop_last=True)

train_loader = data_loading(df_train)
test_loader  = test_loading(df_test)

# load and create model
model = FineTuneCNN(model_path=None,
                    encoder_builder=builder,
                    n_classes=df.activity.nunique(),
                    hidden_layer=hidden_layer)

In [7]:
from torchsummary import summary

summary(model,(70,6400),device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 128, 397]         573,568
       BatchNorm1d-2             [-1, 128, 397]               0
            Conv1d-3              [-1, 256, 48]         524,544
       BatchNorm1d-4              [-1, 256, 48]               0
            Conv1d-5              [-1, 512, 12]         524,800
       BatchNorm1d-6              [-1, 512, 12]               0
           Flatten-7                 [-1, 6144]               0
         ConvNet1D-8                 [-1, 6144]               0
            Linear-9                  [-1, 128]         786,560
        LeakyReLU-10                  [-1, 128]               0
          Dropout-11                  [-1, 128]               0
           Linear-12                    [-1, 6]             774
       Classifier-13                    [-1, 6]               0
Total params: 2,410,246
Trainable param

In [None]:
# train with poutyne
mdl = Model(model,'adam','cross_entropy',
            batch_metrics=['accuracy'],
            epoch_metrics=[poutyne.F1('micro'),poutyne.F1('macro')]).to(device)
history = mdl.fit_generator(train_generator=train_loader,valid_generator=test_loader,epochs=finetune_epochs)

[35mEpoch: [36m1/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.41s [35mloss:[94m 5.362283[35m acc:[94m 27.698864[35m fscore_micro:[94m 0.276989[35m fscore_macro:[94m 0.125666[35m val_loss:[94m 2.019791[35m val_acc:[94m 25.757576[35m val_fscore_micro:[94m 0.257576[35m val_fscore_macro:[94m 0.182119[0m
[35mEpoch: [36m2/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.88s [35mloss:[94m 1.825870[35m acc:[94m 27.911932[35m fscore_micro:[94m 0.279119[35m fscore_macro:[94m 0.083249[35m val_loss:[94m 1.874961[35m val_acc:[94m 17.424242[35m val_fscore_micro:[94m 0.174242[35m val_fscore_macro:[94m 0.070608[0m
[35mEpoch: [36m3/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m14.05s [35mloss:[94m 1.697204[35m acc:[94m 31.321023[35m fscore_micro:[94m 0.313210[35m fscore_macro:[94m 0.122568[35m val_loss:[94m 1.866872[35m val_acc:[94m 17.424242[

[35mEpoch: [36m24/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.86s [35mloss:[94m 1.028631[35m acc:[94m 57.457386[35m fscore_micro:[94m 0.574574[35m fscore_macro:[94m 0.580397[35m val_loss:[94m 1.635183[35m val_acc:[94m 46.212120[35m val_fscore_micro:[94m 0.462121[35m val_fscore_macro:[94m 0.454124[0m
[35mEpoch: [36m25/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m18.16s [35mloss:[94m 1.054567[35m acc:[94m 55.894886[35m fscore_micro:[94m 0.558949[35m fscore_macro:[94m 0.554625[35m val_loss:[94m 1.901995[35m val_acc:[94m 44.696968[35m val_fscore_micro:[94m 0.446970[35m val_fscore_macro:[94m 0.465119[0m
[35mEpoch: [36m26/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.85s [35mloss:[94m 1.110745[35m acc:[94m 54.190341[35m fscore_micro:[94m 0.541903[35m fscore_macro:[94m 0.521133[35m val_loss:[94m 1.580997[35m val_acc:[94m 44.69696

[35mEpoch: [36m47/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.70s [35mloss:[94m 0.423865[35m acc:[94m 84.517045[35m fscore_micro:[94m 0.845170[35m fscore_macro:[94m 0.845812[35m val_loss:[94m 2.799918[35m val_acc:[94m 34.090908[35m val_fscore_micro:[94m 0.340909[35m val_fscore_macro:[94m 0.331275[0m
[35mEpoch: [36m48/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.65s [35mloss:[94m 0.495933[35m acc:[94m 78.409091[35m fscore_micro:[94m 0.784091[35m fscore_macro:[94m 0.778150[35m val_loss:[94m 3.288219[35m val_acc:[94m 30.303032[35m val_fscore_micro:[94m 0.303030[35m val_fscore_macro:[94m 0.288576[0m
[35mEpoch: [36m49/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m15.29s [35mloss:[94m 0.431835[35m acc:[94m 83.025568[35m fscore_micro:[94m 0.830256[35m fscore_macro:[94m 0.839309[35m val_loss:[94m 2.921170[35m val_acc:[94m 37.87878

[35mEpoch: [36m70/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.89s [35mloss:[94m 0.187044[35m acc:[94m 93.536932[35m fscore_micro:[94m 0.935369[35m fscore_macro:[94m 0.937039[35m val_loss:[94m 3.402895[35m val_acc:[94m 34.848484[35m val_fscore_micro:[94m 0.348485[35m val_fscore_macro:[94m 0.334606[0m
[35mEpoch: [36m71/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.99s [35mloss:[94m 0.190732[35m acc:[94m 93.963068[35m fscore_micro:[94m 0.939631[35m fscore_macro:[94m 0.934849[35m val_loss:[94m 4.937306[35m val_acc:[94m 31.060606[35m val_fscore_micro:[94m 0.310606[35m val_fscore_macro:[94m 0.299074[0m
[35mEpoch: [36m72/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.91s [35mloss:[94m 0.160890[35m acc:[94m 95.028409[35m fscore_micro:[94m 0.950284[35m fscore_macro:[94m 0.954318[35m val_loss:[94m 4.958186[35m val_acc:[94m 29.54545

[35mEpoch: [36m93/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.94s [35mloss:[94m 0.061373[35m acc:[94m 98.153409[35m fscore_micro:[94m 0.981534[35m fscore_macro:[94m 0.979182[35m val_loss:[94m 5.285191[35m val_acc:[94m 31.060606[35m val_fscore_micro:[94m 0.310606[35m val_fscore_macro:[94m 0.290889[0m
[35mEpoch: [36m94/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.92s [35mloss:[94m 0.058112[35m acc:[94m 98.579545[35m fscore_micro:[94m 0.985795[35m fscore_macro:[94m 0.984720[35m val_loss:[94m 5.774714[35m val_acc:[94m 29.545454[35m val_fscore_micro:[94m 0.295455[35m val_fscore_macro:[94m 0.276617[0m
[35mEpoch: [36m95/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.06s [35mloss:[94m 0.064050[35m acc:[94m 98.011364[35m fscore_micro:[94m 0.980114[35m fscore_macro:[94m 0.976355[35m val_loss:[94m 6.016749[35m val_acc:[94m 32.57575

[35mEpoch: [36m116/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.91s [35mloss:[94m 0.408918[35m acc:[94m 84.801136[35m fscore_micro:[94m 0.848011[35m fscore_macro:[94m 0.828866[35m val_loss:[94m 2.833299[35m val_acc:[94m 29.545454[35m val_fscore_micro:[94m 0.295455[35m val_fscore_macro:[94m 0.291718[0m
[35mEpoch: [36m117/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.07s [35mloss:[94m 0.384685[35m acc:[94m 86.292614[35m fscore_micro:[94m 0.862926[35m fscore_macro:[94m 0.840907[35m val_loss:[94m 4.048282[35m val_acc:[94m 33.333336[35m val_fscore_micro:[94m 0.333333[35m val_fscore_macro:[94m 0.300941[0m
[35mEpoch: [36m118/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.63s [35mloss:[94m 0.312467[35m acc:[94m 88.920455[35m fscore_micro:[94m 0.889205[35m fscore_macro:[94m 0.890229[35m val_loss:[94m 2.757729[35m val_acc:[94m 31.06

[35mEpoch: [36m139/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.93s [35mloss:[94m 0.038825[35m acc:[94m 99.218750[35m fscore_micro:[94m 0.992188[35m fscore_macro:[94m 0.992317[35m val_loss:[94m 6.182098[35m val_acc:[94m 31.060606[35m val_fscore_micro:[94m 0.310606[35m val_fscore_macro:[94m 0.282981[0m
[35mEpoch: [36m140/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.88s [35mloss:[94m 0.040334[35m acc:[94m 98.934659[35m fscore_micro:[94m 0.989347[35m fscore_macro:[94m 0.990597[35m val_loss:[94m 6.046222[35m val_acc:[94m 29.545454[35m val_fscore_micro:[94m 0.295455[35m val_fscore_macro:[94m 0.269882[0m
[35mEpoch: [36m141/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.16s [35mloss:[94m 0.035392[35m acc:[94m 99.289773[35m fscore_micro:[94m 0.992898[35m fscore_macro:[94m 0.993117[35m val_loss:[94m 5.711771[35m val_acc:[94m 30.30

[35mEpoch: [36m162/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m14.40s [35mloss:[94m 0.012858[35m acc:[94m 99.786932[35m fscore_micro:[94m 0.997869[35m fscore_macro:[94m 0.997355[35m val_loss:[94m 5.829240[35m val_acc:[94m 33.333336[35m val_fscore_micro:[94m 0.333333[35m val_fscore_macro:[94m 0.301497[0m
[35mEpoch: [36m163/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.49s [35mloss:[94m 0.013446[35m acc:[94m 99.857955[35m fscore_micro:[94m 0.998580[35m fscore_macro:[94m 0.999222[35m val_loss:[94m 6.021155[35m val_acc:[94m 31.818182[35m val_fscore_micro:[94m 0.318182[35m val_fscore_macro:[94m 0.289277[0m
[35mEpoch: [36m164/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.10s [35mloss:[94m 0.016058[35m acc:[94m 99.715909[35m fscore_micro:[94m 0.997159[35m fscore_macro:[94m 0.997553[35m val_loss:[94m 6.082862[35m val_acc:[94m 34.09

[35mEpoch: [36m185/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m11.95s [35mloss:[94m 0.021325[35m acc:[94m 99.502841[35m fscore_micro:[94m 0.995028[35m fscore_macro:[94m 0.995484[35m val_loss:[94m 5.917440[35m val_acc:[94m 28.030304[35m val_fscore_micro:[94m 0.280303[35m val_fscore_macro:[94m 0.260097[0m
[35mEpoch: [36m186/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.21s [35mloss:[94m 0.030133[35m acc:[94m 99.218750[35m fscore_micro:[94m 0.992188[35m fscore_macro:[94m 0.991183[35m val_loss:[94m 6.224096[35m val_acc:[94m 37.121212[35m val_fscore_micro:[94m 0.371212[35m val_fscore_macro:[94m 0.313702[0m
[35mEpoch: [36m187/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.79s [35mloss:[94m 0.057211[35m acc:[94m 98.224432[35m fscore_micro:[94m 0.982244[35m fscore_macro:[94m 0.985548[35m val_loss:[94m 8.567672[35m val_acc:[94m 25.75

[35mEpoch: [36m208/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m14.44s [35mloss:[94m 0.050025[35m acc:[94m 98.508523[35m fscore_micro:[94m 0.985085[35m fscore_macro:[94m 0.981339[35m val_loss:[94m 6.164179[35m val_acc:[94m 28.030304[35m val_fscore_micro:[94m 0.280303[35m val_fscore_macro:[94m 0.261578[0m
[35mEpoch: [36m209/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m14.54s [35mloss:[94m 0.042719[35m acc:[94m 98.721591[35m fscore_micro:[94m 0.987216[35m fscore_macro:[94m 0.987024[35m val_loss:[94m 6.495796[35m val_acc:[94m 30.303032[35m val_fscore_micro:[94m 0.303030[35m val_fscore_macro:[94m 0.287198[0m
[35mEpoch: [36m210/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.07s [35mloss:[94m 0.035947[35m acc:[94m 98.863636[35m fscore_micro:[94m 0.988636[35m fscore_macro:[94m 0.987460[35m val_loss:[94m 6.258346[35m val_acc:[94m 31.06

[35mEpoch: [36m231/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.68s [35mloss:[94m 0.020127[35m acc:[94m 99.431818[35m fscore_micro:[94m 0.994318[35m fscore_macro:[94m 0.994207[35m val_loss:[94m 7.437157[35m val_acc:[94m 30.303032[35m val_fscore_micro:[94m 0.303030[35m val_fscore_macro:[94m 0.288616[0m
[35mEpoch: [36m232/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.01s [35mloss:[94m 0.014552[35m acc:[94m 99.573864[35m fscore_micro:[94m 0.995739[35m fscore_macro:[94m 0.994657[35m val_loss:[94m 7.621285[35m val_acc:[94m 28.030304[35m val_fscore_micro:[94m 0.280303[35m val_fscore_macro:[94m 0.260645[0m
[35mEpoch: [36m233/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.01s [35mloss:[94m 0.024090[35m acc:[94m 99.715909[35m fscore_micro:[94m 0.997159[35m fscore_macro:[94m 0.997837[35m val_loss:[94m 7.268957[35m val_acc:[94m 31.81

[35mEpoch: [36m254/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.87s [35mloss:[94m 0.012756[35m acc:[94m 99.573864[35m fscore_micro:[94m 0.995739[35m fscore_macro:[94m 0.994411[35m val_loss:[94m 8.417929[35m val_acc:[94m 31.818182[35m val_fscore_micro:[94m 0.318182[35m val_fscore_macro:[94m 0.289500[0m
[35mEpoch: [36m255/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m13.82s [35mloss:[94m 0.039493[35m acc:[94m 98.863636[35m fscore_micro:[94m 0.988636[35m fscore_macro:[94m 0.984803[35m val_loss:[94m 6.660084[35m val_acc:[94m 25.000000[35m val_fscore_micro:[94m 0.250000[35m val_fscore_macro:[94m 0.251397[0m
[35mEpoch: [36m256/300 [35mStep: [36m22/22 [35m100.00% |[35m█████████████████████████[35m|[32m12.55s [35mloss:[94m 0.167953[35m acc:[94m 95.099432[35m fscore_micro:[94m 0.950994[35m fscore_macro:[94m 0.934337[35m val_loss:[94m 8.850465[35m val_acc:[94m 28.78

In [None]:
pd.DataFrame(history).to_csv(record_fname)