In [1]:
from fast_transformers.attention import FullAttention,LinearAttention,CausalLinearAttention
import torch
from protein_dataset import Protein_Dataset,collate_fn
from torch.utils.data import DataLoader
from classifier import Protein_Classifier2,Protein_Classifier_LoRA
from tensorboardX import SummaryWriter
from tqdm import tqdm
from torchvision.ops.focal_loss import sigmoid_focal_loss
import torch.nn.functional as F
# from pynvml import *
from sklearn.metrics import average_precision_score
import numpy as np

N_BATCH =512



torch.manual_seed(0)

dataset=Protein_Dataset(split="train")
dataset_val=Protein_Dataset(split="val")

int2clss = dataset.int2clss

train_dataloader=DataLoader(dataset,batch_size=N_BATCH,collate_fn=collate_fn,shuffle=True)
val_dataloader=DataLoader(dataset,batch_size=N_BATCH,collate_fn=collate_fn)


In [2]:
# from fast_transformers.attention import LocalAttention

In [3]:
from functools import partial
from fast_transformers.feature_maps import Favor,SmoothedRandomFourierFeatures
# FavorAttention = partial(LinearAttention,query_dimensions=256//8,feature_map=partial(Favor,n_dims=64))
FavorAttention = partial(LinearAttention,query_dimensions=256//8,feature_map=Favor)

In [29]:
# layers = [FullAttention,FullAttention,FullAttention,FavorAttention,FavorAttention,FavorAttention]
layers = [LinearAttention for _ in range(6)]
# model = Protein_Classifier_LoRA(layer = LinearAttention,dim=256,n_layers=6,n_heads=8,dim_feedfwd=512,causal=False,r=4)
model = Protein_Classifier2(layers=layers,dim=256,n_layers=6,n_heads=8,dim_feedfwd=512,causal=False)
D = torch.load("./weights/linear_model.pth")["params"]
model.load_state_dict(D,strict=False)
model.cuda()
""

''

In [30]:
LoRA_Params = [param for n,param in model.named_parameters() if "LoRA_adapter" in n]

In [33]:
import numpy as np
params = sum([np.prod(p.size()) for p in model.parameters()])
l_params = sum([np.prod(p.size()) for p in LoRA_Params])
(l_params/params)*100,l_params,params


(0.0, 0, 3177761)

In [31]:
val_loss = 0
val_acc = 0
val_samples = 0
ys = []
preds = []
model.eval()
for src, tgt in tqdm(val_dataloader):
    with torch.no_grad():
        src=src.cuda()
        tgt=tgt.cuda()
        y_oh = F.one_hot(tgt,33)
        logits = model(src)
        loss = sigmoid_focal_loss(logits,y_oh.float()).mean()

        equals = (logits.sigmoid().argmax(1)==tgt).reshape(-1,1).detach().cpu()
        val_acc += torch.sum(equals.type(torch.FloatTensor)).item()
        val_loss += src.shape[0] * loss.item()
        val_samples += src.shape[0]

        ys.append(y_oh.cpu().numpy())
        preds.append(logits.softmax(dim=1).cpu().numpy())

ys = np.vstack(ys)
preds = np.vstack(preds)
test_dict = {}

aps = 0
for i in range(ys.shape[1]):
    ap = average_precision_score(ys[:,i:i+1],preds[:,i:i+1])
    aps += ap
    test_dict[int2clss[i]] = ap
    
mAP = aps/ys.shape[1]


val_loss = (val_loss/val_samples)
val_acc = (val_acc/val_samples)    

100%|█████████████████████████████████████████████████████████████████████████████████| 484/484 [01:42<00:00,  4.70it/s]


In [32]:
val_loss,val_acc,mAP

#Full Attn Weights
# Full : (0.002505175754773276, 0.9286366390628156, 0.9366763471453218)
# Linear: (0.018903504853127244, 0.262754998990103, 0.13435013169830362)


# From Back 1: (0.004965315841136858, 0.7645445364572814, 0.8383416077302199)
# From Back 2: (0.009145309982745726, 0.547247020803878, 0.6100587362213996)
# From Back 3:(0.011577269005403148, 0.44243991112906483, 0.3866045301048777)

#From Front 1: (0.009175984509262177, 0.5768814380933145, 0.5572050056187545)
#From Front 2: (0.016363995414181348, 0.3479256715814987, 0.3177880819238023)

(0.002836259463698521, 0.9066208846697636, 0.9040897279777705)

In [19]:
import time
model.eval()

Amazon_Classifier(
  (enc): Encoder(
    (enc): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (attention): AttentionLayer(
            (inner_attention): LinearAttention(
              (feature_map): ActivationFunctionFeatureMap()
            )
            (query_projection): Linear(in_features=256, out_features=256, bias=True)
            (key_projection): Linear(in_features=256, out_features=256, bias=True)
            (value_projection): Linear(in_features=256, out_features=256, bias=True)
            (out_projection): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
      

In [22]:
t1 = time.time()
for _ in range(1000):
    with torch.no_grad():
        logits = model(src)
t2 = time.time()

t2-t1

#Linear : 6.503004312515259
#Full:7.042819261550903

6.503004312515259

In [23]:
7.042819261550903/6.503004312515259

1.083010086276085

In [24]:
src.shape

torch.Size([64, 191])

In [25]:
src

tensor([[ 36,  26,  16,  ...,   1,   1,   1],
        [694,  12,  11,  ...,   1,   1,   1],
        [  5, 118, 156,  ...,   1,   1,   1],
        ...,
        [  5, 110,  12,  ...,   1,   1,   1],
        [ 36,  11,   8,  ...,   1,   1,   1],
        [ 36, 189,  59,  ...,   1,   1,   1]], device='cuda:0')