In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import fastai
import math
from functools import partial

  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
ROOT = "../../data/protein/classification/sample_512/"
DATA_PATH = ROOT+"1_kmers"
MODEL_PATH = "../../weights/protein/classification/sample_512/test"
SEQUENCE_LENGTH=512
VOCAB_SIZE=20
BERT_CONFIG_FILE = "../../../bert/config/bert_config_file.json"
BERT_WEIGHTS = "../../../bert_pytorch/weights/tpu"

In [3]:
epochs = 1
num_workers = 8 # On cloud 8
batch_size = 64

In [4]:
class BertDataSet(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, path, seq_length, vocab_size=20, records_to_test = 100000):
        self.data = np.load(path)[:records_to_test]
        self.seq_length = seq_length
        self.vocab_size = vocab_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        seq = np.asarray(row[1])
        mask = np.ones(len(row[1]))
        masked_seq = np.copy(seq)        
        indicies = np.sort(np.random.randint(0, len(seq)-1, 20))
        np.put(masked_seq, indicies, 21)
        seq, mask, masked_seq = self.pad([seq, mask, masked_seq], self.seq_length-len(seq))
        label = np.take(seq, indicies)
        output = self.to_int64([masked_seq, mask, indicies, seq, label])
        return output[0:4], output[4]
    
    def pad(self, items_to_pad, pad_width):
        for i in range(len(items_to_pad)):
            items_to_pad[i] = np.pad(items_to_pad[i], mode="constant", pad_width=(0,pad_width))        
        return items_to_pad
        
    def to_int64(self, items_to_convert):
        for i in range(len(items_to_convert)):
            items_to_convert[i] = np.int64(items_to_convert[i])
        return items_to_convert
        

In [5]:
test_ds = BertDataSet(DATA_PATH+"/test/data.npy", SEQUENCE_LENGTH)

In [6]:
(a, b, c, d), e = test_ds.__getitem__(1)
assert (0 == a).argmax(axis=0) == (0 == b).argmax(axis=0)
#assert (0 == a).argmax(axis=0) == (0 == f).argmax(axis=0)
np.count_nonzero(21 == a), c

(0,
 array([ 15,  18,  27,  49,  73,  77, 113, 117, 127, 140, 159, 179, 179, 181, 189, 190, 199, 199, 217, 220], dtype=int64))

In [7]:
from pytorch_pretrained_bert.modeling import *
class BertTest(nn.Module):
    def __init__(self, config):
        super(BertTest, self).__init__()
        self.bert = BertForPreTraining(config)
        
    def forward(self, masked_seq, attention_mask, indicies, masked_lm_labels, token_ids=None):
        prediction_scores, _ = self.bert(masked_seq, 
                                         token_type_ids=token_ids, 
                                         attention_mask=attention_mask, 
                                         masked_lm_labels=masked_lm_labels)
        return prediction_scores

In [8]:
from pytorch_pretrained_bert.modeling import BertConfig
bert_config = BertConfig.from_json_file(BERT_CONFIG_FILE)

In [9]:
bert_test = BertTest(bert_config)
bert_test.bert.load_state_dict(torch.load(BERT_WEIGHTS, map_location='cpu'))
bert_test.to('cuda')

BertTest(
  (bert): BertForPreTraining(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(22, 512)
        (position_embeddings): Embedding(512, 512)
        (token_type_embeddings): Embedding(16, 512)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=512, out_features=512, bias=True)
                (key): Linear(in_features=512, out_features=512, bias=True)
                (value): Linear(in_features=512, out_features=512, bias=True)
                (dropout): Dropout(p=0.1)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=512, out_features=512, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
            

In [10]:
bert_test.bert.bert.embeddings.word_embeddings.weight[0][:10]

tensor([-0.0131,  0.0144, -0.1214,  0.0055, -0.0371, -0.0127, -0.0228, -0.0466,
        -0.0092, -0.0069], device='cuda:0', grad_fn=<SliceBackward>)

In [11]:
for n, p in bert_test.bert.named_parameters():
    if p.requires_grad: 
        p.requires_grad=False

In [12]:
sum(p.numel() for p in bert_test.parameters() if p.requires_grad)

0

In [13]:
from fastprogress.fastprogress import NBProgressBar
from fastai.basic_train import loss_batch
from fastai.basic_data import to_device
from fastai.torch_core import to_np

In [14]:
def batched_index_select(input, dim, index):
    views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, len(input.shape))]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)

In [15]:
dl = DataLoader(test_ds, batch_size=16, shuffle=True)

In [20]:
model = bert_test
loss_func = nn.CrossEntropyLoss()
model.eval()
with torch.no_grad():
    val_losses, val_acc, steps = [],[], 0
    for xb,yb in NBProgressBar(dl):
        xb, yb = to_device((xb, yb), torch.device('cuda'))
        out = model(*xb)
        predicted = batched_index_select(out,1, xb[2]).transpose(1, 2)
        val_losses.append(to_np(loss_func(predicted, yb)))
        val_acc.append(to_np(fastai.accuracy(predicted, yb)))
        steps = steps + 1
        if steps ==10:
            break
        
sum(val_losses)/len(val_losses), sum(val_acc)/len(val_acc)

(0.48197902739048004, 0.9487500131130219)

In [21]:
predicted.argmax(dim=1)[0]

tensor([16,  8,  4, 17, 18, 17, 13, 13,  6, 16,  3, 12,  5,  9, 10, 17,  3, 20,
         2,  2], device='cuda:0')

In [22]:
yb[0]

tensor([16,  8,  4, 20, 18, 17, 13, 13,  6, 16,  3, 12,  5,  9, 10, 17,  3, 20,
         2,  2], device='cuda:0')

In [19]:
xb[2][0]

tensor([  2,   8,  22,  24,  27,  46,  49,  73,  73,  92, 109, 113, 136, 166,
        194, 201, 212, 272, 280, 286], device='cuda:0')