In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
from einops import rearrange, reduce, repeat

from tqdm import tqdm

import time
import copy
from collections import defaultdict
import joblib
import gc
import os

from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

In [None]:
# Prefer the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Transformer encoder hyper-parameters
N = 2              # number of stacked encoder layers
HIDDEN_DIM = 256   # model width; must be divisible by NUM_HEAD
NUM_HEAD = 8       # attention heads
INNER_DIM = 512    # feed-forward inner width

# Special token ids
PAD_IDX = 0
EOS_IDX = 3

In [None]:
ddir = './'

# Pickled source/target arrays for the first dataset variant ...
src_train_path, src_valid_path, trg_train_path, trg_valid_path = (
    os.path.join(ddir, fname)
    for fname in ('src_train.pkl', 'src_valid.pkl', 'trg_train.pkl', 'trg_valid.pkl')
)

# ... and for the second ("*2") variant.
src_train_path2, src_valid_path2, trg_train_path2, trg_valid_path2 = (
    os.path.join(ddir, fname)
    for fname in ('src_train2.pkl', 'src_valid2.pkl', 'trg_train2.pkl', 'trg_valid2.pkl')
)

In [None]:
# joblib.load unpickles the file, which can execute arbitrary code:
# only load trusted, locally produced files here.
src_train, src_valid, trg_train, trg_valid = [
    joblib.load(path)
    for path in (src_train_path, src_valid_path, trg_train_path, trg_valid_path)
]

src_train2, src_valid2, trg_train2, trg_valid2 = [
    joblib.load(path)
    for path in (src_train_path2, src_valid_path2, trg_train_path2, trg_valid_path2)
]

In [None]:
# Distinct training labels, in a stable order. Set iteration order for string
# labels depends on per-process hash randomization, so `list(set(...))` could
# assign a different column to each class after every kernel restart.
# NOTE(review): this cell must run before trg_train is overwritten with its
# one-hot encoding; afterwards its rows are unhashable arrays.
labels = sorted(set(trg_train))
labels

In [None]:
# Map every label to its column index in the one-hot encoding.
labels_dict = {label: column for column, label in enumerate(labels)}
labels_dict

In [None]:
def multiLabelEncoder(labels_dict, target):
    """One-hot encode ``target`` using the label -> column mapping ``labels_dict``.

    Despite the name, every sample gets exactly one active column.

    Args:
        labels_dict: mapping from label to integer column index.
        target: indexable sequence of labels (``target[i]`` for ``i`` in
            ``range(len(target))``).

    Returns:
        float64 ndarray of shape (len(target), len(labels_dict)).

    Raises:
        KeyError: if a label in ``target`` is missing from ``labels_dict``.
    """
    n_samples, n_classes = len(target), len(labels_dict)
    one_hot = np.zeros((n_samples, n_classes))
    for row in range(n_samples):
        column = labels_dict[target[row]]
        one_hot[row, column] = 1
    return one_hot

In [None]:
# Peek at the raw (not yet encoded) validation-2 targets.
trg_valid2[0:5]

In [None]:
# Sanity check: encode the raw validation-2 labels and preview the first rows.
test = multiLabelEncoder(labels_dict=labels_dict, target=trg_valid2)
test[:5]

In [None]:
def _one_hot_once(target):
    """Return ``target`` one-hot encoded with ``labels_dict``.

    A 2-D ndarray is taken to be the output of an earlier run of this cell
    and is returned unchanged. This keeps the cell idempotent; without the
    guard, a re-run would try to hash array rows and crash.
    """
    if isinstance(target, np.ndarray) and target.ndim == 2:
        return target
    return multiLabelEncoder(labels_dict, target)


trg_train = _one_hot_once(trg_train)
trg_valid = _one_hot_once(trg_valid)
trg_train2 = _one_hot_once(trg_train2)
trg_valid2 = _one_hot_once(trg_valid2)
trg_valid2[:5]

In [None]:
# Vocabulary size / sequence length, presumably for the src_train/src_valid split
VOCAB_SIZE = 15*100*8    # 12000
SEQ_LEN = 60*2           # 120

# ... and presumably for the "*2" split.
# NOTE(review): the model cells size the network with VOCAB_SIZE / SEQ_LEN,
# while the DataLoaders are built from the "*2" data. SEQ_LEN2 also exceeds
# the Encoder's default max_length of 1000. Confirm which split is intended.
VOCAB_SIZE2 = 1108*8     # 8864
SEQ_LEN2 = 4674*2        # 9348

BATCH_SIZE = 256

In [None]:
class TrainDataset(Dataset):
    """Map-style dataset pairing source sequences with target vectors.

    Item ``idx`` is ``(src, trg)``:
    - ``src_data[idx]`` is converted via ``torch.Tensor(...).long()`` (int64);
    - ``trg_data[idx]`` is converted via ``torch.Tensor`` (float32 by default).
    """

    def __init__(self, src_data, trg_data):
        super().__init__()
        self.src_data, self.trg_data = src_data, trg_data

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        src_tensor = torch.Tensor(self.src_data[idx]).long()
        trg_tensor = torch.Tensor(self.trg_data[idx])
        return src_tensor, trg_tensor

# Training loader over the "*2" split, reshuffled every epoch.
train_dataset = TrainDataset(src_data=src_train2, trg_data=trg_train2)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)

In [None]:
class ValidDataset(Dataset):
    """Validation dataset yielding ``(src, trg)`` tensor pairs.

    ``src`` goes through ``torch.Tensor(...).long()`` (int64) and ``trg``
    through ``torch.Tensor`` (float32 by default).
    NOTE(review): this duplicates TrainDataset line for line; a single class
    would serve both splits.
    """

    def __init__(self, src_data, trg_data):
        super().__init__()
        self.src_data = src_data
        self.trg_data = trg_data

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        raw_src, raw_trg = self.src_data[idx], self.trg_data[idx]
        return torch.Tensor(raw_src).long(), torch.Tensor(raw_trg)

# Validation loader over the "*2" split; fixed order for comparable metrics.
valid_dataset = ValidDataset(src_data=src_valid2, trg_data=trg_valid2)
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_dim: input/output width.
        inner_dim: width of the hidden expansion layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, hidden_dim, inner_dim, dropout=0.1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.inner_dim = inner_dim

        self.fc1 = nn.Linear(hidden_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input):
        """Apply the block to the last dimension of ``input``; output keeps its shape."""
        hidden = self.relu(self.fc1(input))
        # Dropout must act on the *activated* tensor. The previous version
        # passed the pre-ReLU output here, which silently discarded the ReLU.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [None]:
class Multiheadattention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hidden_dim: model width (d_model); must be divisible by ``num_head``.
        num_head: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, hidden_dim: int, num_head: int, dropout: float = 0.1):
        super().__init__()
        if hidden_dim % num_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_head ({num_head})"
            )

        # embedding_dim, d_model, 512 in paper
        self.hidden_dim = hidden_dim
        # 8 in paper
        self.num_head = num_head
        # head_dim, d_key, d_query, d_value, 64 in paper (= 512 / 8)
        self.head_dim = hidden_dim // num_head
        # sqrt(d_k). Attention logits are divided by this so the softmax does
        # not saturate as head_dim grows. Previously this was the sqrt of an
        # *empty* tensor and was never applied.
        self.scale = self.head_dim ** 0.5

        self.fcQ = nn.Linear(hidden_dim, hidden_dim)
        self.fcK = nn.Linear(hidden_dim, hidden_dim)
        self.fcV = nn.Linear(hidden_dim, hidden_dim)
        self.fcOut = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """(bs, seq_len, num_head*head_dim) -> (bs, num_head, seq_len, head_dim)."""
        bs, seq_len, _ = x.shape
        return x.reshape(bs, seq_len, self.num_head, self.head_dim).transpose(1, 2)

    def forward(self, srcQ, srcK, srcV, mask=None):
        """Attend from ``srcQ`` over ``srcK``/``srcV``.

        Args:
            srcQ, srcK, srcV: (bs, seq_len, hidden_dim) tensors; ``srcK`` and
                ``srcV`` must share a sequence length.
            mask: optional tensor broadcastable to (bs, num_head, q_len, k_len);
                positions where ``mask == 0`` receive (almost) zero weight.

        Returns:
            (bs, q_len, hidden_dim) tensor.
        """
        Q = self._split_heads(self.fcQ(srcQ))
        K = self._split_heads(self.fcK(srcK))
        V = self._split_heads(self.fcV(srcV))

        ##### SCALED DOT PRODUCT ATTENTION ######
        attention_energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            attention_energy = attention_energy.masked_fill(mask == 0, -1e+4)

        attention = torch.softmax(attention_energy, dim=-1)
        result = torch.matmul(self.dropout(attention), V)
        ##### END OF SCALED DOT PRODUCT ATTENTION ######

        # CONCAT heads: (bs, num_head, q_len, head_dim) -> (bs, q_len, hidden_dim)
        bs, _, q_len, _ = result.shape
        result = result.transpose(1, 2).reshape(bs, q_len, self.hidden_dim)

        return self.fcOut(result)

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder block.

    Self-attention and the feed-forward network each sit inside a residual
    branch of the form ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hidden_dim, num_head, inner_dim):
        super().__init__()

        self.hidden_dim, self.num_head, self.inner_dim = hidden_dim, num_head, inner_dim

        # Sub-modules are created in this order so seeded initialisation and
        # parameter ordering stay stable.
        self.multiheadattention = Multiheadattention(hidden_dim, num_head)
        self.ffn = FFN(hidden_dim, inner_dim)
        self.layerNorm1 = nn.LayerNorm(hidden_dim)
        self.layerNorm2 = nn.LayerNorm(hidden_dim)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)

    def forward(self, input, mask = None):
        # Self-attention residual branch
        attended = self.dropout1(
            self.multiheadattention(srcQ=input, srcK=input, srcV=input, mask=mask)
        )
        hidden = self.layerNorm1(input + attended)

        # Feed-forward residual branch
        transformed = self.dropout2(self.ffn(hidden))
        return self.layerNorm2(hidden + transformed)

In [None]:
class Encoder(nn.Module):
    """Token + learned positional embeddings followed by ``N`` EncoderLayers.

    Args:
        N: number of stacked encoder layers.
        hidden_dim, num_head, inner_dim: forwarded to every EncoderLayer.
        max_length: rows in the positional-embedding table, i.e. the longest
            input sequence accepted.
        vocab_size: rows in the token-embedding table; defaults to the
            notebook-level ``VOCAB_SIZE``.
        padding_idx: passed to ``nn.Embedding``. The default -1 keeps the
            previous behaviour (row ``vocab_size - 1``).
            NOTE(review): the config cell defines ``PAD_IDX = 0``, which is
            not used here. Confirm which id the data pads with.
    """

    def __init__(self, N, hidden_dim, num_head, inner_dim, max_length=1000,
                 vocab_size=None, padding_idx=-1):
        super().__init__()

        # N : number of encoder layer repeated
        self.N = N
        self.hidden_dim = hidden_dim
        self.num_head = num_head
        self.inner_dim = inner_dim
        self.max_length = max_length

        if vocab_size is None:
            vocab_size = VOCAB_SIZE

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim,
                                      padding_idx=padding_idx)
        self.pos_embedding = nn.Embedding(max_length, hidden_dim)
        self.layers = nn.ModuleList([EncoderLayer(hidden_dim, num_head, inner_dim) for _ in range(N)])

        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input, mask=None):
        """Encode token ids of shape (batch, seq_len) into (batch, seq_len, hidden_dim).

        Args:
            input: integer token-id tensor (batch, seq_len).
            mask: optional; forwarded to every layer's attention, where
                positions with ``mask == 0`` are suppressed. Must broadcast
                against (batch, num_head, seq_len, seq_len).

        Raises:
            ValueError: if seq_len exceeds ``max_length``. Without this check
                the positional lookup fails with an opaque index error.
        """
        seq_len = input.shape[1]
        if seq_len > self.max_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length={self.max_length}"
            )

        # (1, seq_len) position ids on the input's device; broadcasting over
        # the batch replaces the explicit .repeat(batch_size, 1).
        pos = torch.arange(seq_len, device=input.device).unsqueeze(0)

        # Dropout is applied once to the summed embeddings (previously it was
        # applied twice in a row).
        output = self.dropout(self.embedding(input) + self.pos_embedding(pos))

        # N encoder layer
        for layer in self.layers:
            output = layer(output, mask=mask)

        return output

In [None]:
class Transformer(nn.Module):
    """Encoder-only classifier: Encoder -> flatten -> small MLP head.

    Args:
        N, hidden_dim, num_head, inner_dim: encoder hyper-parameters.
        seq_len: fixed input length; the MLP head consumes the flattened
            (seq_len * hidden_dim) encoder output. Defaults to the
            notebook-level ``SEQ_LEN``.
        num_classes: number of output scores (default 8, as before).
    """

    def __init__(self, N = 2, hidden_dim = 256, num_head = 64, inner_dim = 512,
                 seq_len=None, num_classes=8):
        super().__init__()
        if seq_len is None:
            seq_len = SEQ_LEN
        self.seq_len = seq_len

        self.encoder = Encoder(N, hidden_dim, num_head, inner_dim)
        # Input width follows hidden_dim; it was previously hard-coded to 256,
        # which broke any other hidden_dim.
        # NOTE(review): the first two Linear layers have no activation between
        # them, so together they form a single rank<=16 linear map. Kept as-is
        # to preserve the state_dict layout.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * seq_len, 64),
            nn.Linear(64, 16),
            nn.GELU(),
            nn.Linear(16, num_classes),
        )

    def forward(self, src):
        """Map token ids (batch, seq_len) to unnormalised scores (batch, num_classes)."""
        output = self.encoder(src)                      # (batch, seq_len, hidden_dim)
        output = torch.flatten(output, start_dim=1)     # (batch, seq_len * hidden_dim)
        return self.mlp(output)

In [None]:
# Shape of the first-split source data as a tensor
torch.Tensor(src_train).size()

In [None]:
# Smoke test: one forward pass on the first 10 samples of src_train.
# no_grad avoids building an autograd graph for this throw-away check.
# train_one_epoch switches the model back to train mode.
model = Transformer(N, HIDDEN_DIM, NUM_HEAD, INNER_DIM).to(device)
model.eval()
with torch.no_grad():
    x = torch.Tensor(src_train[:10]).long().to(device)
    y = model(x)
y.shape

In [None]:
# First output score of the first sample
y[0][0]

In [None]:
# Second output score of the first sample
y[0][1]

In [None]:
# CrossEntropyLoss applies log-softmax internally, so it should receive raw
# model scores. The targets here are one-hot float rows, which it treats as
# class probabilities.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training epoch over ``dataloader``.

    For each batch this function:
    - runs a forward pass and computes the loss with the notebook-level ``criterion``;
    - back-propagates and clips the gradient norm to 1;
    - takes an optimizer step;
    - steps ``scheduler`` if one is given, so the schedule advances once per
      batch, not once per epoch.

    Args:
        model: module called as ``model(src=...)``, returning raw scores.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, or ``None``.
        dataloader: yields ``(src, trg)`` batches. ``trg`` is assumed to be
            one-hot float rows of shape (batch, num_classes) -- TODO confirm.
        device: device the batches are moved to.
        epoch: epoch number, used only for the progress bar.

    Returns:
        (epoch_loss, accuracy) as floats:
        - epoch_loss: sample-weighted mean loss;
        - accuracy: fraction of samples whose argmax prediction matches the
          argmax target.
        Both are NaN if ``dataloader`` is empty.
    """
    model.train()

    dataset_size = 0
    running_loss = 0.0
    running_correct = 0
    epoch_loss = float("nan")
    accuracy = float("nan")

    bar = tqdm(dataloader, total=len(dataloader))

    for src, trg in bar:
        src = src.to(device)
        trg = trg.to(device)

        batch_size = src.shape[0]

        pred = model(src=src)

        # CrossEntropyLoss expects raw scores. The previous softmax over dim=0
        # normalised across the *batch* and then got log-softmaxed again.
        loss = criterion(pred, trg)

        # Clear stale gradients before backward, so leftovers from outside
        # this loop cannot leak into the first step.
        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

        optimizer.step()

        # change learning rate by Scheduler
        if scheduler is not None:
            scheduler.step()

        running_loss += loss.item() * batch_size
        running_correct += (pred.argmax(dim=1) == trg.argmax(dim=1)).sum().item()
        dataset_size += batch_size

        # Sample-weighted, so a smaller final batch is not over-counted
        epoch_loss = running_loss / dataset_size
        accuracy = running_correct / dataset_size

        bar.set_postfix(
            Epoch=epoch, Train_Loss=epoch_loss, LR=optimizer.param_groups[0]["lr"], accuracy=accuracy
        )

    gc.collect()

    return epoch_loss, accuracy

In [None]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    """Evaluate ``model`` over ``dataloader`` without updating any weights.

    Args:
        model: module called as ``model(src=...)``, returning raw scores.
        dataloader: yields ``(src, trg)`` batches. ``trg`` is assumed to be
            one-hot float rows of shape (batch, num_classes) -- TODO confirm.
        device: device the batches are moved to.
        epoch: epoch number, used only for the progress bar.

    Returns:
        (val_loss, accuracy) as floats:
        - val_loss: sample-weighted mean of the notebook-level ``criterion``;
        - accuracy: fraction of samples whose argmax prediction matches the
          argmax target.
        Both are NaN if ``dataloader`` is empty.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    running_correct = 0
    val_loss = float("nan")
    accuracy = float("nan")

    bar = tqdm(dataloader, total=len(dataloader))

    for src, trg in bar:
        src = src.to(device)
        trg = trg.to(device)

        batch_size = src.shape[0]

        pred = model(src=src)
        # Raw scores, consistent with CrossEntropyLoss. The previous softmax
        # over dim=0 normalised across the batch.
        loss = criterion(pred, trg)

        running_loss += loss.item() * batch_size
        running_correct += (pred.argmax(dim=1) == trg.argmax(dim=1)).sum().item()
        dataset_size += batch_size

        val_loss = running_loss / dataset_size
        accuracy = running_correct / dataset_size

        # The LR is not shown here: it lives on the training optimizer, which
        # this function does not receive.
        bar.set_postfix(Epoch=epoch, Valid_Loss=val_loss, accuracy=accuracy)

    gc.collect()

    return val_loss, accuracy

In [None]:
def run_training(
    model,
    optimizer,
    scheduler,
    device,
    num_epochs,
    metric_prefix="",
    file_prefix="",
    early_stopping=True,
    early_stopping_step=10,
    train_loader=None,
    valid_loader=None,
):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Args:
        model, optimizer, scheduler, device: passed to the per-epoch helpers.
        num_epochs: maximum number of epochs.
        metric_prefix: prefix for the keys in the returned ``history``.
        file_prefix: prefix for checkpoint file names.
        early_stopping: whether to stop when validation loss stalls.
        early_stopping_step: training stops once the number of consecutive
            non-improving epochs exceeds this value, i.e. after
            ``early_stopping_step + 1`` such epochs.
        train_loader / valid_loader: DataLoaders to use. They default to the
            notebook-level ``train_dataloader`` / ``valid_dataloader``.

    Returns:
        (model, history):
        - model: the best-validation-loss weights reloaded into ``model``;
        - history: dict mapping metric names to per-epoch lists.

    Side effects:
        Every improvement writes two checkpoint files, which are never cleaned up:
        ``{file_prefix}epoch{N}_Loss{loss}.bin`` and ``{file_prefix}best_{N}epoch.bin``.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if valid_loader is None:
        valid_loader = valid_dataloader

    if torch.cuda.is_available():
        print("[INFO] Using GPU:{}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    history = defaultdict(list)
    early_stop_counter = 0

    for epoch in range(1, num_epochs + 1):
        gc.collect()

        train_epoch_loss, train_accuracy = train_one_epoch(
            model,
            optimizer,
            scheduler,
            dataloader=train_loader,
            device=device,
            epoch=epoch,
        )

        val_loss, val_accuracy = valid_one_epoch(
            model, valid_loader, device=device, epoch=epoch
        )

        history[f"{metric_prefix}Train Loss"].append(train_epoch_loss)
        history[f"{metric_prefix}Train Accuracy"].append(train_accuracy)
        history[f"{metric_prefix}Valid Loss"].append(val_loss)
        history[f"{metric_prefix}Valid Accuracy"].append(val_accuracy)

        print(f"Valid Loss : {val_loss}")

        if val_loss <= best_loss:
            early_stop_counter = 0

            print(
                f"Validation Loss improved( {best_loss} ---> {val_loss}  )"
            )

            # Update Best Loss
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

            PATH = "{}epoch{:.0f}_Loss{:.4f}.bin".format(file_prefix, epoch, best_loss)
            torch.save(model.state_dict(), PATH)
            torch.save(model.state_dict(), f"{file_prefix}best_{epoch}epoch.bin")

            print("Model Saved")

        elif early_stopping:
            early_stop_counter += 1
            if early_stop_counter > early_stopping_step:
                break

    end = time.time()
    time_elapsed = end - start
    print(
        "Training complete in {:.0f}h {:.0f}m {:.0f}s".format(
            time_elapsed // 3600,
            (time_elapsed % 3600) // 60,
            (time_elapsed % 3600) % 60,
        )
    )
    print("Best Loss: {:.4f}".format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history


In [None]:
# Keep the returned history. Previously the (model, history) tuple was
# discarded, and the per-epoch metrics were lost.
# NOTE: train_one_epoch steps the scheduler once per *batch*, so T_max counts
# optimizer steps, not epochs.
model, history = run_training(
    model=model,
    optimizer=optimizer,
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=1e-5),
    device=device,
    num_epochs=20000,
    metric_prefix="",
    file_prefix="",
    early_stopping=True,
    early_stopping_step=10,
)