<a href="https://colab.research.google.com/github/castlechoi/studyingDL/blob/main/TimeSeries/TF_C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [None]:
class Config():
  """Hyper-parameters shared by the dataset, model and training cells.

  Attributes:
    target_num_classes: number of output classes of the fine-tuning classifier.
    feature_len: number of time steps kept from every sample.
    subset: when True, datasets are truncated to a small subset.
    batch_size: mini-batch size.
    num_epochs: number of epochs for both pre-training and fine-tuning.
    sigma: standard deviation of the Gaussian jitter used for time-domain
      augmentation.
  """
  def __init__(self):
    # data feature
    self.target_num_classes = 2
    self.feature_len = 176

    self.subset = True

    # optimizer parameter
    self.batch_size = 64
    self.num_epochs = 40

    # augmentation parameter
    # The attribute used to be a bare `self.sigma` expression, which raised
    # AttributeError as soon as a Config was created.
    # NOTE(review): 0.8 is a placeholder default -- tune it to the scale of
    # the signals being jittered.
    self.sigma = 0.8

In [None]:
class Encoder(nn.Module):
  """1-D convolutional encoder: (batch, 1, feature_len) -> (batch, 128).

  Three Conv1d/BatchNorm/ReLU stages (the first two followed by max-pooling)
  feed a two-layer fully connected head that emits a 128-d embedding.

  Args:
    config: object exposing `feature_len`, the input sequence length. It must
      be a multiple of 8 (the stride of the first convolution).

  Raises:
    ValueError: if `feature_len` is not a multiple of 8 or is too short to
      survive the convolution/pooling stack.
  """

  # (kernel_size, stride, padding) of every length-changing layer, in order.
  # Must be kept in sync with the layers built in __init__.
  _LENGTH_LAYOUT = (
      (8, 8, 0),  # layer1 conv
      (2, 2, 0),  # layer1 max-pool
      (8, 1, 4),  # layer2 conv
      (2, 2, 0),  # layer2 max-pool
      (8, 1, 4),  # layer3 conv
  )

  def __init__(self, config):
    super(Encoder, self).__init__()

    if config.feature_len % 8 != 0:
      raise ValueError(
          f"feature_len must be a multiple of 8, got {config.feature_len}")

    self.layer1 = nn.Sequential(
      nn.Conv1d(1, 32, kernel_size = 8, stride = 8),
      nn.BatchNorm1d(32),
      nn.ReLU(),
      nn.MaxPool1d(kernel_size = 2, stride = 2)
    )
    # padding = 4 on the stride-1 convolutions: without it a 176-step input
    # shrinks to 176 -> 22 -> 11 -> 4 -> 2 steps, which is shorter than the
    # kernel of layer3 and makes the forward pass fail.
    self.layer2 = nn.Sequential(
        nn.Conv1d(32, 64, kernel_size = 8, stride = 1, padding = 4),
        nn.BatchNorm1d(64),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size = 2, stride = 2)
    )
    self.layer3 = nn.Sequential(
        nn.Conv1d(64,128, kernel_size = 8, stride = 1, padding = 4),
        nn.BatchNorm1d(128),
        nn.ReLU(),
    )

    # Flattened size seen by the FC head: channels * remaining time steps.
    self.feature_len = 128 * self._output_length(config.feature_len)

    self.fc = nn.Sequential(
        nn.Linear(self.feature_len,256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Linear(256,128),
        )

  @classmethod
  def _output_length(cls, length):
    """Return the number of time steps left after layer1..layer3."""
    for kernel_size, stride, padding in cls._LENGTH_LAYOUT:
      length = (length + 2 * padding - kernel_size) // stride + 1
      if length < 1:
        raise ValueError("feature_len is too short for the encoder")
    return length

  def forward(self, x):
    #  x: batch_size * 1 * feature_len, e.g. 64 * 1 * 176
    out = self.layer1(x)    # 64 * 32 * 11
    out = self.layer2(out)  # 64 * 64 * 6
    out = self.layer3(out)  # 64 * 128 * 7
    out = out.flatten(1)    # 64 * (128 * 7)
    out = self.fc(out)      # 64 * 128
    return out

In [None]:
class TFC(nn.Module):
  """Twin time/frequency network: one encoder and one projector per domain.

  `forward` returns, in order, the time-domain encoder output, its projection,
  the frequency-domain encoder output and its projection.
  """

  @staticmethod
  def _projection_head():
    """Build one 128 -> 256 -> 128 projector (Linear, BatchNorm, ReLU, Linear)."""
    return nn.Sequential(
        nn.Linear(128, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Linear(256, 128),
    )

  def __init__(self,config):
    super().__init__()
    # Sub-modules are created in the same order as before (both encoders,
    # then both projectors) so parameter names and seeded initialisation
    # stay the same.
    self.encoder_t = Encoder(config)
    self.encoder_f = Encoder(config)
    self.projector_t = self._projection_head()
    self.projector_f = self._projection_head()

  def forward(self,x_in_t, x_in_f):
    # Encode each domain with its own encoder, then project each embedding.
    time_embedding = self.encoder_t(x_in_t)
    freq_embedding = self.encoder_f(x_in_f)
    return (
        time_embedding,
        self.projector_t(time_embedding),
        freq_embedding,
        self.projector_f(freq_embedding),
    )

In [None]:
class TaskClassifier(nn.Module):
  """Classification head applied to the concatenated time/frequency embeddings.

  Args:
    config: object exposing `target_num_classes`, the number of output logits.

  `forward(x_t, x_f)` expects two tensors whose flattened widths add up to 256
  (the input width of `fc1`) and returns unnormalised class scores of shape
  (batch, target_num_classes), suitable for `nn.CrossEntropyLoss`.
  """
  def __init__(self,config):
    super(TaskClassifier, self).__init__()
    self.fc1 = nn.Linear(256,64)
    self.fc2 = nn.Linear(64,config.target_num_classes)

    # The attribute was named `sigmoid` but held nn.Softmax(dim = 1); a softmax
    # between the two linear layers squashes the hidden units into a
    # distribution. Use the activation the name promises.
    self.sigmoid = nn.Sigmoid()

  def forward(self, x_t, x_f):
    # Flatten each embedding and join them along the feature axis *before*
    # fc1; the old code read an undefined `x` and ignored both arguments.
    features = torch.cat((x_t.flatten(1), x_f.flatten(1)), dim = 1)
    out = self.fc1(features)
    out = self.sigmoid(out)
    out = self.fc2(out)
    return out

In [None]:
class NTXentLoss(nn.Module):
  """Normalised temperature-scaled cross-entropy (NT-Xent) contrastive loss.

  Row i of the first batch and row i of the second batch form a positive pair;
  every other row of the combined 2N-row batch is a negative for it.

  Args:
    temperature: positive scaling factor applied to the cosine similarities.

  Raises:
    ValueError: if `temperature` is not positive.
  """
  def __init__(self, temperature = 0.2):
    super(NTXentLoss, self).__init__()
    if temperature <= 0:
      raise ValueError(f"temperature must be positive, got {temperature}")
    self.temperature = temperature

  def forward(self, x, x_pair = None):
    """Return the scalar NT-Xent loss between `x` and `x_pair`.

    Args:
      x: (N, D) embeddings of the first view.
      x_pair: (N, D) embeddings of the second view. When omitted, `x` is
        returned unchanged, which is what the original one-argument stub did.
    """
    if x_pair is None:
      return x

    batch_size = x.size(0)
    # Unit-normalise so the dot product below is a cosine similarity.
    z = nn.functional.normalize(torch.cat((x, x_pair), dim = 0), dim = 1)
    logits = z @ z.t() / self.temperature

    # A sample must never be its own candidate: mask the diagonal out.
    self_mask = torch.eye(2 * batch_size, dtype = torch.bool, device = z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive of row i is row i + N (and the other way round).
    targets = (torch.arange(2 * batch_size, device = z.device) + batch_size) % (2 * batch_size)
    return nn.functional.cross_entropy(logits, targets)

In [None]:
def dataAugmentation(data, config, domain = "time", remove_ratio = 0.1):
  """Return an augmented copy of `data`.

  The stray leading `self` parameter was removed: this is a module-level
  function and its callers pass `(data, config, domain)`.

  Args:
    data: torch tensor to augment.
    config: object exposing `sigma`, the standard deviation of the jitter
      noise (only read when `domain == "time"`).
    domain: "time" adds Gaussian jitter; any other value zeroes out random
      entries (intended for frequency-domain data).
    remove_ratio: fraction of entries dropped in the non-"time" branch.
      NOTE(review): 0.1 is a chosen default; the previous threshold of 0 kept
      every entry, so nothing was ever removed.

  Returns:
    A new tensor with the same shape as `data`.
  """
  if domain == "time":
    # jittering: additive zero-mean Gaussian noise created on data's device
    noise = torch.randn(data.shape, device = data.device) * config.sigma
    return data + noise

  # remove freq: keep an entry only when its uniform draw exceeds the ratio.
  # torch.rand(..., device=...) also works on CPU-only machines, unlike
  # torch.cuda.FloatTensor.
  mask = torch.rand(data.shape, device = data.device) > remove_ratio
  return data * mask

In [None]:
class LoadDataset(Dataset):
  """Dataset yielding time-domain and frequency-domain views of each sample.

  Every item is the 5-tuple
  `(x, x_aug, x_f, x_f_aug, label)`. In "pre_train" mode the `*_aug` entries
  are augmented once at construction time; in any other mode they are the
  un-augmented data again.

  Args:
    dataset: mapping with keys "samples" and "labels".
      Samples are sliced as `[:, :1, :feature_len]`, so they are assumed to be
      laid out as (num_samples, channels, length) -- TODO confirm against the
      data files.
    config: object exposing `subset` and `feature_len` (plus whatever
      `dataAugmentation` reads in "pre_train" mode).
    training_mode: "pre_train" enables augmentation; other values disable it.
  """

  # Number of samples kept when config.subset is enabled.
  SUBSET_SIZE = 64 * 10

  def __init__(self, dataset, config, training_mode):
    super(LoadDataset, self).__init__()
    # __getitem__ branches on the mode, so it has to be stored.
    self.training_mode = training_mode

    self.x_train = dataset["samples"]
    self.y_train = dataset["labels"]

    # np -> torch (a no-op for inputs that are already tensors)
    x_train = torch.as_tensor(self.x_train)
    y_train = torch.as_tensor(self.y_train)

    # shuffle the data; one shared permutation keeps samples and labels paired
    order = torch.randperm(x_train.shape[0])
    x_train, y_train = x_train[order], y_train[order]

    if config.subset:
      x_train = x_train[:self.SUBSET_SIZE]
      y_train = y_train[:self.SUBSET_SIZE]

    # Keep the first channel and the first feature_len steps. float32 matches
    # the default dtype of nn.Module weights; int64 is what CrossEntropyLoss
    # expects for class-index targets.
    self.x_data = x_train[:, :1, :config.feature_len].float()
    self.y_data = y_train.long()

    # Magnitude spectrum along the last (time) axis.
    self.x_data_f = torch.fft.fft(self.x_data).abs()
    self.len = self.x_data.shape[0]

    if training_mode == "pre_train":
      self.aug1 = dataAugmentation(self.x_data, config,"time")
      self.aug1_f = dataAugmentation(self.x_data_f, config,"freq")


  def __getitem__(self, index):
    if self.training_mode == "pre_train":
      return self.x_data[index], self.aug1[index], self.x_data_f[index], self.aug1_f[index], self.y_data[index]
    else:
      return self.x_data[index], self.x_data[index], self.x_data_f[index], self.x_data_f[index], self.y_data[index]
  def __len__(self):
    return self.len

In [None]:
def load_dataset(config):
  """Load the pre-training and fine-tuning splits and wrap them in DataLoaders.

  Reads `./SleepEEG/train.pt` for pre-training and `./Epilepsy/{train,valid,
  test}.pt` for fine-tuning.

  Args:
    config: object exposing `batch_size`, plus everything `LoadDataset` reads.

  Returns:
    `(pretrain_loader, finetune_train_loader, finetune_valid_loader,
    finetune_test_loader)`.
  """
  batch_size = config.batch_size

  def _load(folder, filename, training_mode):
    # NOTE(security): torch.load unpickles the file; only load data files
    # that come from a trusted source.
    raw = torch.load(os.path.join(folder, filename))
    return LoadDataset(raw, config, training_mode)

  def _loader(dataset, train):
    # Training loaders shuffle and drop the ragged last batch -- unless the
    # whole split is smaller than one batch, where dropping it would leave
    # the loader empty. Evaluation loaders keep every sample, in order.
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=train,
        drop_last=train and len(dataset) >= batch_size,
        num_workers=0,
    )

  pretrain_dataset = _load('./SleepEEG', "train.pt", "pre_train")
  finetune_train_dataset = _load('./Epilepsy', "train.pt", "finetune")
  finetune_valid_dataset = _load('./Epilepsy', "valid.pt", "finetune")
  finetune_test_dataset = _load('./Epilepsy', "test.pt", "finetune")

  pret_loader = _loader(pretrain_dataset, train=True)
  fine_train_loader = _loader(finetune_train_dataset, train=True)
  fine_valid_loader = _loader(finetune_valid_dataset, train=False)
  fine_test_loader = _loader(finetune_test_dataset, train=False)

  return pret_loader,fine_train_loader,fine_valid_loader,fine_test_loader

In [None]:
# main
# Shared hyper-parameters first, then the two networks that consume them.
# (The backbone is built before the classifier so seeded initialisation is
# unchanged.)
config = Config()
model = TFC(config)
classifier = TaskClassifier(config)

# Objectives: contrastive loss for the embeddings, cross-entropy for the
# class predictions.
pret_criterion = NTXentLoss()
classification_criterion = nn.CrossEntropyLoss()

# One Adam optimizer per network, both with the same learning rate and
# weight decay.
model_optimizer = optim.Adam(model.parameters(), lr = 3e-4, weight_decay = 5e-4)
classifier_optimizer = optim.Adam(classifier.parameters(), lr = 3e-4, weight_decay = 5e-4)

# load data
pret_dl, fine_train_dl, fine_val_dl, fine_test_dl = load_dataset(config)

In [None]:
def _clone_state(module):
  """Return an independent copy of `module.state_dict()`.

  state_dict() tensors share storage with the live parameters, so a snapshot
  must be cloned or it silently follows later optimizer steps.
  """
  return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


def _tfc_loss(model, criterion, x_t, x_t_aug, x_f, x_f_aug, lam = 0.5):
  """Compute the time/frequency consistency loss for one batch.

  Returns `(loss, z_t, z_f)`; the projections of the un-augmented batch are
  handed back so the caller can feed them to a classifier without a second
  forward pass.
  """
  h_t, z_t, h_f, z_f = model(x_t, x_f)
  h_t_aug, z_t_aug, h_f_aug, z_f_aug = model(x_t_aug, x_f_aug)

  loss_t = criterion(h_t, h_t_aug)
  loss_f = criterion(h_f, h_f_aug)
  # positive pair in TF embedding space
  loss_c_p = criterion(z_t, z_f)
  # negative pairs in TF embedding space
  loss_c_n1 = criterion(z_t, z_f_aug)
  loss_c_n2 = criterion(z_f, z_t_aug)
  loss_c_n3 = criterion(z_t_aug, z_f_aug)

  loss_c = (loss_c_p - loss_c_n1 + 1) + (loss_c_p - loss_c_n2 + 1) + (loss_c_p - loss_c_n3 + 1)
  return lam * (loss_t + loss_f) + (1 - lam) * loss_c, z_t, z_f


def _mean_classification_loss(model, classifier, criterion, loader):
  """Return the mean per-batch classification loss over `loader`.

  Both networks are switched to eval mode so BatchNorm uses its running
  statistics instead of updating them on evaluation data.
  """
  model.eval()
  classifier.eval()
  total, num_batches = 0.0, 0
  with torch.no_grad():
    for x_t, _, x_f, _, label in loader:
      _, z_t, _, z_f = model(x_t, x_f)
      total += criterion(classifier(z_t, z_f), label).item()
      num_batches += 1
  # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
  return total / max(num_batches, 1)


# Pre-Train
model.train()
for epoch in range(config.num_epochs):
  pret_loss = 0
  num_batches = 0
  for x_t, x_t_aug, x_f, x_f_aug, _ in pret_dl:
    model_optimizer.zero_grad()
    loss_tfc, _, _ = _tfc_loss(model, pret_criterion, x_t, x_t_aug, x_f, x_f_aug)

    loss_tfc.backward()
    model_optimizer.step()
    pret_loss += loss_tfc.item()
    num_batches += 1

  # Average over batches; dividing by `epoch` raised ZeroDivisionError on the
  # first epoch and was not a mean anyway.
  print(f'Epoch {epoch+1} Loss : {pret_loss / max(num_batches, 1)} ')

# FineTuning
valid_loss_list = []  # mean validation loss of every epoch
best_valid_loss = float("inf")
best_model = _clone_state(model)
best_classifier = _clone_state(classifier)
for epoch in range(config.num_epochs):
  model.train()
  classifier.train()
  finetune_loss = 0
  num_batches = 0
  # The augmented views are unpacked too: the consistency loss needs
  # embeddings of *this* batch, not leftovers from the pre-training loop.
  for x_t, x_t_aug, x_f, x_f_aug, label in fine_train_dl:
    model_optimizer.zero_grad()
    classifier_optimizer.zero_grad()
    loss_tfc, z_t, z_f = _tfc_loss(model, pret_criterion, x_t, x_t_aug, x_f, x_f_aug)
    pred_class = classifier(z_t,z_f)
    loss_p = classification_criterion(pred_class, label)

    loss = loss_p + loss_tfc

    loss.backward()
    model_optimizer.step()
    # step(), not zero_grad(): otherwise the classifier is never updated.
    classifier_optimizer.step()
    finetune_loss += loss.item()
    num_batches += 1

  print(f'Finetune Epoch {epoch+1} Loss : {finetune_loss / max(num_batches, 1)} ')

  # valid
  valid_loss = _mean_classification_loss(model, classifier, classification_criterion, fine_val_dl)
  print(f'Valid Epoch {epoch+1} Loss : {valid_loss} ')
  valid_loss_list.append(valid_loss)

  # Compare against the best loss seen *before* this epoch; comparing with
  # min() of a list that already contains valid_loss could never succeed.
  if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    best_model = _clone_state(model)
    best_classifier = _clone_state(classifier)
  else:
    # No improvement: roll both networks back to the best snapshot.
    model.load_state_dict(best_model)
    classifier.load_state_dict(best_classifier)

  # test
  test_loss = _mean_classification_loss(model, classifier, classification_criterion, fine_test_dl)
  print(f'Test Epoch {epoch+1} Loss : {test_loss} ')