In [None]:
import os
from collections import Counter

import numpy as np
import pandas as pd
import seaborn as sns  # for heatmaps
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchsummary import summary
import torchvision

# Contrastive Learning for Unseen Environments

In [None]:
# Random seeds for reproducibility (numpy global RNG and torch; torch.manual_seed
# seeds every device's generator).
SEED = 1024
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU setting: prefer cuda:1 when a second GPU exists, fall back to cuda:0 on a
# single-GPU machine, and to the CPU when CUDA is unavailable.
# torch.cuda.set_device only accepts CUDA devices (and valid ordinals), so it must
# not be called for the CPU fallback.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    DEVICE = torch.device(f"cuda:{_gpu_index}")
    torch.cuda.set_device(DEVICE)
else:
    DEVICE = torch.device("cpu")
device = DEVICE

In [None]:
# Experiment switches (alternatives listed in the trailing comments).
TRAINMODE   = 'simclr'   # alternative: 'normal'
PAIRING     = 'nuc2'     # alternatives: 'pwr', 'csi'
NETWORK     = 'shallow'  # alternatives: 'alexnet', 'resnet'
TEMPERATURE = 'L'        # temperature code: 'L', 'M' or 'H'

In [None]:
# --- data locations (spectrogram folders per experiment) ---
fp2 = './data/experiment_data/exp_2/spectrogram'
fp3 = './data/experiment_data/exp_3/spectrogram'
fp4 = './data/experiment_data/exp_4/spectrogram_multi'
OUT_PATH = '.'            # root of the laboratory/saved_models and laboratory/records folders

# --- data selection ---
JOINT = 'first'           # forwarded to select_train_test_dataset
SKIP = ['standff']        # activity list forwarded to the label filtering step (optionally also 'waving')

# --- training schedule ---
BATCH_SIZE = 64
NUM_WORKERS = 0
PRETRAIN_EPOCHS = 1200
LAB_FINETUNE_EPOCHS = 200
FIELD_FINETUNE_EPOCHS = 200
REGULARIZE = None         # forwarded to train(regularize=...)

## 0. Setup 

### a. data

In [None]:
from data.spectrogram import import_data, import_pair_data
from data import *

### b. model setup

In [None]:
from models.baseline import *
from models.cnn import *
from models.self_supervised import *
from models.utils import *
from models import *

In [None]:
from losses import NT_Xent
from contrastive_learning import pretrain
from train import record_log,save,evaluation,train,load_checkpoint

In [None]:
def create_encoder(network, pairing):
    """Build an encoder backbone for a network type and input modality.

    Args:
        network: one of 'shallow', 'alexnet', 'resnet'.
        pairing: one of 'csi', 'nuc2', 'pwr'.

    Returns:
        (encoder, outsize): the encoder module and its output size, which callers
        pass on as ``in_size`` / ``out_size1`` of the heads. For 'shallow' the size
        is hard-coded here; for the other networks it comes from the factory.

    Raises:
        ValueError: for an unknown network, or (for a known network) an unknown
            pairing. The network is validated first.
    """
    if network not in ('shallow', 'alexnet', 'resnet'):
        raise ValueError("network must be in {'shallow','alexnet','resnet'}")
    if pairing not in ('csi', 'nuc2', 'pwr'):
        raise ValueError("pairing must be in {'csi','nuc2','pwr'}")

    if network == 'shallow':
        # 'csi' and 'nuc2' share one configuration; 'pwr' uses scale_factor=3.
        if pairing == 'pwr':
            return create_baseline_encoder(scale_factor=3), 1152
        return create_baseline_encoder(scale_factor=1), 960

    if network == 'alexnet':
        if pairing == 'pwr':
            encoder, outsize = create_alexnet((4, 1), scale_factor=2)
        else:
            encoder, outsize = create_alexnet((1, 4), scale_factor=1)
        return encoder, outsize

    # 'resnet': every pairing uses the same configuration. (Previously a plain `if`
    # instead of `elif` made 'csi' fall through to the ValueError branch.)
    encoder, outsize = create_resnet18((2, 2))
    return encoder, outsize

In [None]:
def record_log(record_outpath,exp_name,phase,record='None',cmtx='None',cls='None',loss_rec=True,acc_rec=False):
    """Write the training history and evaluation tables of one phase to CSV.

    Files are written as ``<record_outpath>/<exp_name>_Phase_<phase><suffix>.csv``.
    ``record``, ``cmtx`` and ``cls`` are optional: a string (including the
    historical ``'None'`` default) or ``None`` means "nothing to write".

    Args:
        record_outpath: output directory; created if it does not exist.
        exp_name: experiment name used in the file prefix.
        phase: phase name used in the file prefix.
        record: mapping with a ``'train'`` sequence (written to ``_loss.csv`` when
            ``loss_rec``) and a ``'validation'`` sequence (written to
            ``_accuracy.csv`` when ``acc_rec``).
        cmtx: table with ``to_csv`` (as returned by ``evaluation``), written to ``_cmtx.csv``.
        cls: table with ``to_csv`` (as returned by ``evaluation``), written to ``_report.csv``.
        loss_rec: whether to write the training-loss history.
        acc_rec: whether to write the validation-accuracy history.
    """
    def _given(obj):
        # Strings act as "not provided" sentinels; also accept a real None.
        return obj is not None and not isinstance(obj, str)

    # Create the directory up front so a long run does not fail at its first write.
    os.makedirs(record_outpath, exist_ok=True)
    prefix = record_outpath + '/' + exp_name + '_Phase_' + phase

    if _given(record):
        if loss_rec:
            pd.DataFrame(record['train'], columns=['train_loss']).to_csv(prefix + '_loss.csv')
        if acc_rec:
            pd.DataFrame(record['validation'], columns=['validation_accuracy']).to_csv(prefix + '_accuracy.csv')
    if _given(cmtx):
        cmtx.to_csv(prefix + '_cmtx.csv')
    if _given(cls):
        cls.to_csv(prefix + '_report.csv')

In [None]:
def save_model(model_outpath,exp_name,phase,model):
    """Save the state dict of ``model`` and return the path it was written to.

    The file is ``<model_outpath>/<exp_name>_Phase_<phase>`` (no extension). The
    directory is created if it does not exist.
    """
    os.makedirs(model_outpath, exist_ok=True)
    model_fp = model_outpath + '/' + exp_name + '_Phase_' + phase
    # Save the module that was passed in. Previously the global `encoder` was
    # saved, so the lab-finetuned classifier never reached disk.
    torch.save(model.state_dict(), model_fp)
    return model_fp

def load_encoder(network, model_fp, pairing='csi'):
    """Rebuild an encoder with ``create_encoder`` and load weights from ``model_fp``.

    Args:
        network: network type passed to ``create_encoder``.
        model_fp: path of a state dict written by ``save_model``.
        pairing: modality passed to ``create_encoder`` (default 'csi', the encoder
            saved after pretraining).

    Returns:
        (encoder, outsize) as returned by ``create_encoder``.
    """
    encoder, outsize = create_encoder(network, pairing)
    # map_location='cpu' lets weights saved from any GPU (e.g. cuda:1) be loaded on
    # machines without that device; load_state_dict copies them into the module.
    encoder.load_state_dict(torch.load(model_fp, map_location='cpu'))
    return encoder, outsize

def load_model(network, freeze, lb, model_fp, pairing='csi'):
    """Rebuild encoder + classifier and load a saved state dict into it.

    Args:
        network: network type passed to ``create_encoder``.
        freeze: forwarded to ``add_classifier``.
        lb: fitted label encoder; ``len(lb.classes_)`` sets the classifier output size.
        model_fp: path of a state dict written by ``save_model``.
        pairing: modality passed to ``create_encoder`` (default 'csi').

    Returns:
        The classifier model with the loaded weights.
    """
    encoder, outsize = create_encoder(network, pairing)
    # out_size is the number of classes, matching how the classifier is built in
    # the lab-finetuning loop (previously the class array itself was passed).
    model = add_classifier(encoder, in_size=outsize, out_size=len(lb.classes_), freeze=freeze)
    model.load_state_dict(torch.load(model_fp, map_location='cpu'))
    return model
    


# Main

In [None]:
# Working names for the configuration constants used throughout the run.
trainmode = TRAINMODE
pairing = PAIRING
network = NETWORK
joint = JOINT
activities = SKIP               # forwarded to filtering_activities_and_label_encoding
batch_size = BATCH_SIZE
num_workers = NUM_WORKERS
pretrain_epochs = PRETRAIN_EPOCHS
lab_finetune_epochs = LAB_FINETUNE_EPOCHS
field_finetune_epochs = FIELD_FINETUNE_EPOCHS
regularize = REGULARIZE

# Translate the temperature code into the value handed to NT_Xent.
t = TEMPERATURE
temperature_dict = {'L': 0.1, 'M': 0.5, 'H': 1}
temperature = temperature_dict[t]

# 'normal' runs drop the pairing and temperature tags from the experiment name
# and use freeze=False for add_classifier; every other mode uses freeze=True.
freeze = TRAINMODE != 'normal'
if not freeze:
    pairing = ''
    t = ''

exp_name = f'Trainmode-{trainmode}-{t}_Network-{network}_Data-exp4csi{pairing}'
model_outpath = '/'.join([OUT_PATH, 'laboratory', 'saved_models'])
record_outpath = '/'.join([OUT_PATH, 'laboratory', 'records'])

print(exp_name)

In [None]:
####################################################### Main #############################################################
# Load the paired CSI / `pairing` recordings of experiment 4, split them into
# train/test, pick the training inputs according to `joint`, then filter by
# `activities` and label-encode (the fitted encoder is kept as `lb`).
X1, X2, y = import_pair_data(fp4, modal=['csi', pairing])

(X1_train, X1_test,
 X2_train, X2_test,
 y_train, y_test) = split_datasets(X1, X2, y)

(X_train, X_test,
 y_train, y_test) = select_train_test_dataset(
    X1_train, X1_test, X2_train, X2_test, y_train, y_test, joint)

(X_train, X_test,
 y_train, y_test, lb) = filtering_activities_and_label_encoding(
    X_train, X_test, y_train, y_test, activities)

In [None]:
################################ Lab-Pretrain-phase ################################

### data
# Paired loader over the CSI and `pairing` training arrays.
pretrain_loader = create_dataloader(X1_train, X2_train, batch_size=batch_size, num_workers=num_workers)
print('X1_train: ', X1_train.shape, '\tX2_train: ', X2_train.shape)

### network
# CSI encoder plus an encoder for the paired modality, wrapped in one SimCLR model.
encoder, outsize = create_encoder(network, 'csi')
encoder2, outsize2 = create_encoder(network, pairing)
simclr = add_SimCLR_multi(enc1=encoder, enc2=encoder2, out_size1=outsize, out_size2=outsize2)

### pretraining
phase = 'pretrain'
if trainmode == 'simclr':
    criterion = NT_Xent(batch_size, temperature, world_size=1)
    optimizer = torch.optim.SGD(simclr.parameters(), lr=0.0005)
    simclr, record = pretrain(model=simclr,
                              train_loader=pretrain_loader,
                              criterion=criterion,
                              optimizer=optimizer,
                              end=pretrain_epochs,
                              device=device)
    # `record` only exists when pretraining ran; logging it outside this branch
    # raised NameError for any other train mode.
    record_log(record_outpath, exp_name, phase, record=record)
    # The optimizer holds references to the model parameters; drop it so that
    # empty_cache() below can actually release their GPU memory.
    del criterion, optimizer

# Save the wrapper's encoder; the finetuning loop reloads it via
# load_encoder(network, encoder_fp).
encoder_fp = save_model(model_outpath, exp_name, phase, simclr.encoder)
del simclr, encoder, encoder2
torch.cuda.empty_cache()

################################ Lab-Finetuning-phase ################################


# The untuned model is evaluated only on the first pass of each phase.
inital = {'lab': True, 'field': True}
for sampling in ['weight', 1, 5, 10]:

    print('sampling: ', sampling)

    ### Sampling data ###
    lab_finetune_loader, lab_validatn_loader, class_weight = combine1(X_train, X_test, y_train, y_test,
                                                                      sampling, lb, batch_size, num_workers)
    print("class: ", lb.classes_)
    print("class_size: ", 1 - class_weight)

    ### model
    encoder, outsize = load_encoder(network, encoder_fp)
    model = add_classifier(encoder, in_size=outsize, out_size=len(lb.classes_), freeze=freeze)

    # initialization
    phase = 'lab-initial'
    if inital['lab']:
        # NOTE(review): scores the untuned model on the finetune (training) loader;
        # confirm whether the validation loader was intended.
        cmtx, cls = evaluation(model, lab_finetune_loader, label_encoder=lb)
        record_log(record_outpath, exp_name, phase, cmtx=cmtx, cls=cls)
        inital['lab'] = False

    # finetuning
    phase = 'lab-finetune'
    criterion = nn.CrossEntropyLoss(weight=class_weight).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    model, record = train(model=model,
                          train_loader=lab_finetune_loader,
                          criterion=criterion,
                          optimizer=optimizer,
                          end=lab_finetune_epochs,
                          test_loader=lab_validatn_loader,
                          device=device,
                          regularize=regularize)

    # record: tag the files with the sampling strategy. Otherwise every iteration
    # writes the same '<exp_name>_Phase_lab-finetune_*' files and only the last
    # strategy's results survive.
    cmtx, cls = evaluation(model, lab_validatn_loader, label_encoder=lb)
    record_log(record_outpath, exp_name, f'{phase}-sampling-{sampling}',
               record=record, cmtx=cmtx, cls=cls, acc_rec=True)

    # save: only the class-weighted model is kept; its path (model_fp) is what
    # the field-finetuning cell loads.
    if sampling == 'weight':
        model_fp = save_model(model_outpath, exp_name, phase, model)

    del encoder, model, record, cmtx, cls, criterion, optimizer
    torch.cuda.empty_cache()

In [None]:
# Deliberate stop for "Run All", presumably so the field-finetuning cell below is
# only run manually. An explicit raise (unlike `assert`) is not removed under -O.
raise RuntimeError("Execution stopped on purpose before the field-finetuning phase; run the next cell manually.")

In [None]:
################################ Field-Finetuning-phase ################################
# Starts from the class-weighted lab model (model_fp, saved by the lab-finetuning
# loop) and finetunes it on experiment-2 data with num = 1, 5, 10 passed to
# process_field_data. This cell was previously indented as if nested in the lab
# loop, which raised IndentationError at top level.

# The field data does not depend on the sampling level: load it once.
Xf, yf = import_data(fp2)

for field_sampling in [1, 5, 10]:

    field_finetune_loader, field_validatn_loader, lb = process_field_data(Xf, yf, num=field_sampling, lb=lb)

    model = load_model(network, freeze, lb, model_fp)

    # initialization: score the lab model on field data before field training.
    # (`test_loader` was undefined here; use the field validation loader, the
    # same one used for the post-finetune evaluation below.)
    phase = 'field-initial'
    if inital['field']:
        cmtx, cls = evaluation(model, field_validatn_loader, label_encoder=lb)
        record_log(record_outpath, exp_name, phase, cmtx=cmtx, cls=cls)
        inital['field'] = False

    # finetuning
    phase = 'field-finetune'
    # NOTE(review): `class_weight` is whatever the last lab-finetuning iteration
    # left (sampling=10); when this code sat inside the `sampling == 'weight'`
    # branch it was that iteration's weights. Confirm which is intended.
    criterion = nn.CrossEntropyLoss(weight=class_weight).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    model, record = train(model=model,
                          train_loader=field_finetune_loader,
                          criterion=criterion,
                          optimizer=optimizer,
                          end=field_finetune_epochs,
                          test_loader=field_validatn_loader,
                          device=device,
                          regularize=regularize)

    # record: keep the fresh evaluation (it was previously discarded, so stale
    # metrics were logged) and tag the files with the sampling level so runs do
    # not overwrite each other.
    cmtx, cls = evaluation(model, field_validatn_loader, label_encoder=lb)
    record_log(record_outpath, exp_name, f'{phase}-sampling-{field_sampling}',
               record=record, cmtx=cmtx, cls=cls, acc_rec=True)
    # `encoder` is not defined in this cell, so it is no longer deleted here.
    del model, record, cmtx, cls, criterion, optimizer
    torch.cuda.empty_cache()