In [1]:
from load_data import CreateDataLoader 
from networks.embedding_net import EmbedderAndPositionEncoder
from networks.attention_net import TransformerBlock
from networks.classifier_net import Classification 
from networks.transformer_net import TransformerClassification
import torch
from trainer import trainer 
from evaluate import evaluate 

### Load dataset

In [2]:
loader = CreateDataLoader()
train, val, test = loader.make_loader()
vocab_size = loader.vocab_size

# Peek at the first training batch to inspect the input_ids shape.
# Its sequence length (dim 1) is used to size the model in the training cell.
# next(iter(...)) raises StopIteration immediately if the loader is empty,
# instead of failing later on `None.size()`.
first_batch = next(iter(train))
sample = first_batch["input_ids"]

print(sample.size())
print(vocab_size)


torch.Size([32, 256])
121855


### Model training (the best model is saved to `.onnx`)

In [3]:
# Hyperparameters for this run, named so they can be tuned in one place.
SEED = 42
EMBEDDING_DIM = 300
NUM_CLASSES = 2  # the inference cell maps class 0 -> "Negative", others -> "Positive"
LEARNING_RATE = 2e-5
NUM_EPOCHS = 10  # NOTE(review): assumed meaning of trainer's last positional arg — confirm

# Seed torch's global RNG so weight initialization (and any torch-RNG-driven
# batch shuffling) is reproducible across kernel restarts.
torch.manual_seed(SEED)

net = TransformerClassification(vocab_size=vocab_size,
                               n_token=sample.size(1),  # sequence length of input_ids
                               embedding_dim=EMBEDDING_DIM,
                               tag_size=NUM_CLASSES)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

model = trainer(train, val, net, criterion, optimizer, NUM_EPOCHS)

device: cpu


KeyboardInterrupt: 

### Evaluate the model on the test dataset

In [None]:
# Evaluate on the held-out test split with the same loss used for training.
# NOTE(review): `model` is assigned only by the training cell
# (`model = trainer(...)`); that cell was interrupted in this run, so on a
# fresh kernel this line fails with NameError unless training completes.
evaluate(test, model, criterion)

### Visualize **attention weights**

In [4]:
from visualization_attention import mk_html
from IPython.display import HTML

In [9]:
# Take the first test batch to visualize.
sample = next(iter(test))
inputs = sample["input_ids"]

# Put the network in inference mode (affects layers such as dropout, if any)
# so the attention maps reflect deterministic predictions.
net.eval()
with torch.no_grad():
    output, attn = net(inputs, attention_flg=True)

preds = output.argmax(-1)
word2index = loader.word2index

# Index of the example within the batch to render.
EXAMPLE_INDEX = 3

# Render the attention weights the model actually produced for this batch.
# NOTE(review): assumes `attn` has the layout mk_html expects, presumably
# (batch, n_token, n_token) like the earlier torch.rand(4, 256, 256) placeholder — confirm.
html_text = mk_html(EXAMPLE_INDEX, sample, preds, attn, word2index)

In [10]:
# Render the attention-weight HTML built by mk_html in the previous cell.
HTML(html_text)

### Inference

In [11]:
from preprocessing.dump_preprocessing import load_dump_prep

In [18]:
transform = load_dump_prep()
dummy_text = "hello world"
# NOTE(review): assumes transform.transform returns input already shaped as a
# single-example batch that `net` accepts — confirm.
dummy_inputs = transform.transform(dummy_text)

# Inference mode (affects layers such as dropout, if any) so the prediction
# is deterministic.
net.eval()
with torch.no_grad():
    logits = net(dummy_inputs)

pred = logits.argmax(-1)
# Class index 0 is labelled "Negative"; any other index is "Positive".
# .item() makes the single-prediction assumption explicit (raises if not 1 element).
pred_str = "Negative" if pred.item() == 0 else "Positive"
print(f"input text: {dummy_text}")
print(f"predict: {pred_str}")

input text: hello world
predict: Positive
