In [None]:
!pip install tokenizers

In [None]:
# Google Drive can only be mounted on Colab. This notebook also targets Kaggle
# and local runs (see the runtime flags below), where `google.colab` is not
# importable, so mounting is skipped there instead of aborting Run All.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')

In [None]:
import os
import subprocess

REPO_URL = 'https://github.com/cgjeong23/Viral-genomic-classification.git'
REPO_DIR = 'virus'

# `git clone` refuses to clone into an existing non-empty directory, so only
# clone on the first run; this keeps "Restart & Run All" working. check=True
# makes a failed clone stop the notebook here rather than surfacing later as
# an ImportError on `virus.ML`.
if not os.path.isdir(REPO_DIR):
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)

## Imports

In [None]:
from virus.ML.model import SkipGramEmbeddingModel
from virus.ML.train import train, evaluate
from virus.ML.dataloader import load_sequences, sample_data, get_3_splits, SequenceDataset
from torch import nn
from torch.utils.data import DataLoader

import numpy as np
import os

# Mode 2 reloads all (non-excluded) modules before each cell executes, so edits
# to the cloned `virus` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load data and train the skip-gram embedding model

In [None]:
import random

import torch

# Runtime selection: set at most one of these to True. With both False the
# notebook reads data and writes outputs relative to the working directory.
use_google_drive = False
use_kaggle = True

# Seed every RNG the pipeline may touch so Restart & Run All reproduces the
# same run: torch drives weight init and DataLoader shuffling; NumPy/`random`
# cover the train/valid/test split if get_3_splits relies on the global RNGs
# (assumed — confirm in virus.ML.dataloader).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
google_drive_path = '/content/drive/MyDrive'

# Locate the training data directory and the pretrained tokenizer for the
# active runtime (Google Drive takes precedence over Kaggle; otherwise use
# paths relative to the working directory).
if use_google_drive:
    base_path, tokenizer_file = (f'{google_drive_path}/trainingdata',
                                 f'{google_drive_path}/gene_tokenizer.json')
elif use_kaggle:
    base_path, tokenizer_file = ('../input/pacific-sra/trainingdata',
                                 '../input/pacific-sra/gene_tokenizer.json')
else:
    base_path, tokenizer_file = 'trainingdata', 'gene_tokenizer.json'

sequences, labels = load_sequences(base_path, train_embedding=True)

# Six-way unpack: three sequence splits followed by their three label splits.
(train_seq, valid_seq, test_seq,
 train_label, valid_label, test_label) = get_3_splits(sequences, labels)

# One shared label -> index mapping, built from ALL labels so every split
# encodes classes identically.
label_dict = {label: index for index, label in enumerate(np.unique(labels))}

# Build train/valid/test datasets (in that order) with the same tokenizer file
# and label mapping.
train_dataset, valid_dataset, test_dataset = (
    SequenceDataset(split_seqs, split_labels,
                    tokenizer_file=tokenizer_file, label_dict=label_dict)
    for split_seqs, split_labels in ((train_seq, train_label),
                                     (valid_seq, valid_label),
                                     (test_seq, test_label))
)

In [None]:
# Optimisation settings.
lr, batch_size, num_epochs = 1e-2, 5000, 1

# Vocabulary size and padding id come from the tokenizer attached to the
# training dataset, so the embedding table matches the token encoding.
vocab_size = train_dataset.tokenizer.get_vocab_size()
pad_id = train_dataset.tokenizer.padding['pad_id']

# Skip-gram architecture: embedding width and context window size.
embedding_dim, window_size = 256, 2

In [None]:
import torch

# Use the GPU when present, otherwise fall back to CPU instead of crashing on
# a hard-coded `.to('cuda')`.
# NOTE(review): if train() moves batches to 'cuda' itself rather than to the
# model's device, CPU-only runs will still fail inside train() — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = SkipGramEmbeddingModel(vocab_size, embedding_dim, pad_id, window_size).to(device)

# Only the training split is shuffled; validation/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          collate_fn=train_dataset.collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size,
                          collate_fn=valid_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         collate_fn=test_dataset.collate_fn)

# Target positions equal to pad_id are ignored by the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=pad_id)

In [None]:
kaggle_path = '/kaggle/working'

# Output location passed to train() as `base_path`. The branch order mirrors
# the data-loading cell (Google Drive first, then Kaggle), so inputs and
# outputs always belong to the same runtime even if both flags are True.
if use_google_drive:
    save_path = google_drive_path
elif use_kaggle:
    save_path = kaggle_path
else:
    save_path = '.'

loss_history = train(model, train_loader, loss_function, lr, num_epochs,
                     valid_loader=valid_loader, test_loader=test_loader, train_skip_gram=True,
                     base_path=save_path)

In [None]:
# Display the loss history returned by train() as this cell's output.
loss_history

In [None]:
import matplotlib.pyplot as plt

# Plot every recorded split on shared axes. Markers are drawn because a
# one-point line without markers renders as nothing — e.g. if each split holds
# one value per epoch and num_epochs = 1 (assumed; confirm against train()).
fig, ax = plt.subplots()
for split in ('train', 'valid', 'test'):
    if split in loss_history:  # tolerate a run that did not record a split
        ax.plot(loss_history[split], marker='o', label=split)
ax.set(title='Skip-gram embedding loss', xlabel='epoch', ylabel='loss')
ax.legend();