# Loading data and training a viral classifier on metagenomic sequences from mosquitoes

## Import dependencies 

In [1]:
import pandas as pd 
import numpy as np 

from torchmetagen.datasets.utils import FastaHandler, DatasetSplit, InflateDataset
from torchmetagen.datasets import metagenomicdataset as meta
from torchmetagen.models import DeepVirFinder, deepvirfinder
from torchmetagen.transforms import *
from torchvision import transforms as tf
import torch

from utils import *

## Check for GPU devices

In [2]:
# Seed the global numpy and torch generators up front so that any step relying
# on them is repeatable after "Restart Kernel -> Run All".
# NOTE(review): helpers from torchmetagen / utils may draw from other RNGs
# (e.g. Python's `random`) — confirm, and seed those too if so.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Location passed to FastaHandler together with each FASTA file name.
path_to_file = 'Dataset_v1_2'

# One handler per class, built in the same order as before: viral first,
# then non-viral.
viral, nonviral = (
    FastaHandler(path_to_file, fasta_name)
    for fasta_name in ('viral.fasta', 'nonviral.fasta')
)



In [4]:
# Split specification handed to DatasetSplit: 0.7 under 'train', 0.3 under 'val'.
split_fractions = {'train': 0.7, 'val': 0.3}
splitter = DatasetSplit(split_fractions)

# Each class is split on its own with the same splitter; each call is unpacked
# into a (train, held-out) pair.
viral_train, viral_test = splitter(viral)
nonviral_train, nonviral_test = splitter(nonviral)


In [5]:

# Inflation is configured with the 'truncated' method, tol=0.5 and chunk_size=500.
inflate = InflateDataset(method='truncated', tol=0.5, chunk_size=500)

# Only the training splits go through `inflate`.
viral_train_inflated = inflate(viral_train)
nonviral_train_inflated = inflate(nonviral_train)

# The held-out splits are passed through unchanged (inflation is switched off
# for them); the `_inflated` suffix is kept only so later cells can use a
# uniform naming scheme.
viral_test_inflated = viral_test
nonviral_test_inflated = nonviral_test

In [6]:
# Symbol order handed to ToOneHot, and the two class names used both as the
# values of the "class" column and as the dataset's `labels`.
ONE_HOT_ALPHABET = ('G', 'T', 'C', 'A')
CLASS_LABELS = ('nonviral', 'viral')


def _sequence_pipeline():
    """Return a fresh transform chain: ReverseComplement -> ToOneHot -> ToTensor."""
    return tf.Compose([
        ReverseComplement(),
        ToOneHot(list(ONE_HOT_ALPHABET)),
        ToTensor('one-hot'),
    ])


def _make_dataset(nonviral_seqs, viral_seqs, transform):
    """Build a MetagenomicSequenceData over non-viral followed by viral sequences.

    The underlying DataFrame has two columns: "data" (the concatenated
    sequences, non-viral first) and "class" (the matching "nonviral" / "viral"
    label for every row).
    """
    frame = pd.DataFrame({
        "data": np.concatenate((nonviral_seqs, viral_seqs)),
        "class": np.concatenate((
            np.repeat(CLASS_LABELS[0], len(nonviral_seqs)),
            np.repeat(CLASS_LABELS[1], len(viral_seqs)),
        )),
    })
    return meta.MetagenomicSequenceData(
        frame, labels=list(CLASS_LABELS), transform=transform
    )


# Train and validation use the same chain, each with its own Compose instance.
# NOTE(review): ReverseComplement is also applied at validation time; if it is a
# random augmentation this makes evaluation non-deterministic — confirm intent.
transforms_train = _sequence_pipeline()
transforms_test = _sequence_pipeline()

dataset_train = _make_dataset(
    nonviral_train_inflated, viral_train_inflated, transforms_train
)
dataset_test = _make_dataset(
    nonviral_test_inflated, viral_test_inflated, transforms_test
)

# Per-split lookups consumed by the data-loader and training cells.
dataset = {'train': dataset_train, 'val': dataset_test}
dataset_sizes = {split: len(split_data) for split, split_data in dataset.items()}

In [7]:
# Per-split integer setting for genDataLoader (from the local `utils` module);
# presumably the batch size of each loader — TODO confirm against utils.
per_split_sizes = {'train': 250, 'val': 1}
dataloaders = genDataLoader(dataset, per_split_sizes)

In [8]:
# Build the network from scratch (no pretrained weights).
# NOTE(review): the meaning of M, K and N is defined by
# torchmetagen.models.deepvirfinder — confirm there before tuning them.
model_torch = deepvirfinder(
    pretrained=False,
    progress=True,
    M=1000,
    K=10,
    N=1000,
)


In [9]:

# Optimisation set-up: Adam over every model parameter, binary cross-entropy as
# the loss, and a StepLR schedule with a step size of 4.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model_torch.parameters(), lr=learning_rate)
criterion = torch.nn.BCELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4)

# `train_model` comes from the local `utils` module; it is run for one epoch
# and its two return values are kept as `per_epoch` and `per_batch`.
per_epoch, per_batch = train_model(
    model_torch.to(device),
    criterion,
    optimizer,
    scheduler,
    dataloaders,
    device,
    dataset_sizes,
    num_epochs=1,
)

Epoch 0/0
----------
train Loss: 0.6962 Acc: 0.5067


[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


val Loss: 0.6701 Acc: 0.8335

Training complete in 1m 23s
Best val Acc: 0.833521


In [10]:


# Score the model on the validation loader and show the result as a table.
# NOTE(review): in the output recorded in this notebook, class 1 has a recall
# of about 0.003 with supports of 1479 vs 293, so the 0.83 accuracy is
# essentially the share of class 0 — worth confirming before relying on it.
val_report = evaluate(model_torch.to(device), dataloaders['val'], device)
pd.DataFrame(val_report)


Acc test_set: 0.8335


Unnamed: 0,0,1,accuracy,macro avg,weighted avg
precision,0.834842,0.25,0.833521,0.542421,0.738138
recall,0.997972,0.003413,0.833521,0.500692,0.833521
f1-score,0.909147,0.006734,0.833521,0.45794,0.759933
support,1479.0,293.0,0.833521,1772.0,1772.0
