In [1]:
import os
import sys
# Make the parent directory importable so the local `brainle` package can be imported below
# (assumes the notebook is run from a subdirectory of the project root — TODO confirm).
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
_project_root = os.path.abspath('..')
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [2]:
from typing import List

import time 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat

def count_parameters(model: nn.Module):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted; each one
    contributes its ``numel()`` (total number of elements).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are excluded
        total += param.numel()
    return total

def count_parameters_all(model: nn.Module):
    """Return the total number of scalar parameters in ``model``.

    Unlike ``count_parameters``, parameters with ``requires_grad=False``
    (frozen ones) are included in the count.
    """
    sizes = [param.numel() for param in model.parameters()]
    return sum(sizes)

In [75]:
# Text 008: masked-character infilling with the `text_008_last.ckpt` checkpoint.
# NOTE(review): this header previously read "Text 005" while the checkpoint loaded is text_008 —
# confirm the ConvTeNet hyperparameters below match the run that produced that checkpoint.
#
# Every MASK_CHAR in `text` is hidden from the model (mask value 0). The model then predicts a
# character for every position, and the most likely prediction is decoded back into a string.

from brainle.models.sm_model import SMModel
from brainle.models.architectures.attention import ConvTeNet
from brainle.datamodules.wikitext_datamodule import WikiTextDatamodule

block_size = 1024
MASK_CHAR = 'X'  # characters of `text` equal to this are masked out
PAD_CHAR = '.'   # right-padding used to fill the prompt up to `block_size`

datamodule = WikiTextDatamodule(
    train_val_split =  [10537, 100],
    batch_size = 24,
    num_workers = 0,
    block_size = block_size,
    p_word_mask = 0.15,
    p_char_mask = 0.05
)
datamodule.setup()

net = ConvTeNet(
    vocabulary_size = 837,  # equals the "Alphabet size 837" printed when this cell ran
    embedding_dim = 256,
    num_layers = 7,
    num_heads = 8,
    num_attention_layers = 4,
    window_size = 4,
    use_skip = True
)
model = SMModel.load_from_checkpoint(
    checkpoint_path = '../data/ckpts/text_008_last.ckpt',
    model = net,
    learning_rate = 1e-4
)
# Inference only: disable training-time behaviour (e.g. dropout, if the model uses any).
model.eval()

text = 'WXat doXs XXXs XXXn meXX? XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX '
# str.ljust never truncates, so a longer prompt would silently overflow block_size.
assert len(text) <= block_size, f'prompt has {len(text)} characters but block_size is {block_size}'

input_tokens = rearrange(torch.tensor(datamodule.dataset.encode(text.ljust(block_size, PAD_CHAR))), 'n -> 1 n')
# 1 = visible, 0 = masked. Only characters of `text` itself can be masked; the padding stays visible.
mask_values = [0 if char == MASK_CHAR else 1 for char in text] + [1] * (block_size - len(text))
input_mask = rearrange(torch.tensor(mask_values), 'n -> 1 n')

# No autograd graph is needed for prediction.
with torch.no_grad():
    logits = model(input_tokens, input_mask)
out = F.softmax(logits, dim=-1)  # (1, block_size, vocab) probabilities per position
# Top-1 character id per position.
ids = rearrange(out.argmax(dim=-1), '1 s -> s').numpy()
out_text = datamodule.dataset.decode(ids)
out_text

Reusing dataset wikitext (/Users/flavioschneider/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


Alphabet size 837


'What does t ts tirn met ? Ieteetettee  einttin  o   a  ei  e eeeine  e aeotiia  in  a eeie eeititttios ................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................