In [2]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules before
# each cell execution, so edits to the imported project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from src.dataset.dataset import TextDataset
import torch 
from torch import nn

In [3]:
# Keyword arguments for TextDataset, gathered in one place so they are easy to tweak.
dataset_kwargs = dict(
    corpus_path='data/corpus.txt',
    json_path='data/tiny_stories',
    save_tokenizer_to='bpe',
    max_len=256,
    vocab_size=5000,
)
data = TextDataset(**dataset_kwargs)

# NOTE(review): the label says "train data", but this is the size of the whole
# dataset -- the train/test split is only made afterwards with random_split.
print('created train data with size:', len(data))

0it [00:00, ?it/s]

created train data with size: 20869767


In [4]:
# Reproducible 95% / 5% split of `data`, driven by a dedicated generator seeded with 42.
generator = torch.Generator()
generator.manual_seed(42)
datasets = torch.utils.data.random_split(
    dataset=data,
    lengths=[0.95, 0.05],
    generator=generator,
)
train_data, test_data = datasets

# One size per line: train first, then test.
print(len(train_data), len(test_data), sep='\n')

19826279
1043488


In [4]:
from torch.utils.data import DataLoader
from src.model.language_model import LanguageModel
from src.dataset.dataset import Collator
from src.utils.trainer import CosineAnnealingWithWarmupLR

# Loader settings shared by the train and test loaders.
BATCH_SIZE = 768
NUM_WORKERS = 4

# The collator is configured with the dataset's pad id
# (presumably used to pad sequences inside a batch -- confirm in src.dataset.dataset.Collator).
collate_fn = Collator(pad_value=data.pad_id)
# Only the training loader is shuffled; the test loader keeps a fixed order.
train_loader = DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, collate_fn=collate_fn,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vocabulary size, maximum length and pad index are taken from the dataset so
# that the model always agrees with the data it is trained on.
model = LanguageModel(
    embed_dim=1024,
    vocab_size=data.vocab_size,
    max_len=data.max_len,
    pad_idx=data.pad_id,
    num_layers=16,
    num_heads=32,
    dropout=0,
    feedforward_dim=2048
)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Targets equal to the pad id are excluded from the loss (and from its mean),
# so padded positions do not contribute to the loss or its gradients.
# NOTE(review): assumes the trainer feeds padded targets to `criterion`; if it
# already masks them out, ignore_index is a harmless no-op.
criterion = nn.CrossEntropyLoss(ignore_index=data.pad_id)
# LR scheduler is currently disabled; kept for reference.
# lr_scheduler = CosineAnnealingWithWarmupLR(optimizer=optimizer, warmup_steps=50, max_steps=200)

In [5]:
# Print the module tree to sanity-check the architecture (layer types and sizes).
print(model)

LanguageModel(
  (transformer): TransformerDecoder(
    (embedding): Embedding(5000, 1024, padding_idx=0)
    (pos_encoder): PositionalEncoding()
    (decoder): ModuleList(
      (0-15): 16 x DecoderBlock(
        (Q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (K_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (V_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (masked_multihead_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (lay_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (feedforward): Sequential(
          (0): Linear(in_features=1024, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=1024, bias=True)
        )
        (dropout): Dropout(p=0, inplace=False)
        (lay_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)

In [7]:
from src.utils.wandb_logger import WandbLogger
from src.utils.trainer import train

# Weights & Biases logger for this run.
# NOTE(review): `config` is a placeholder; consider logging the real
# hyperparameters (batch size, lr, model dimensions) so runs are comparable.
wdb = WandbLogger(
    project_name="little-lama",
    config={"loh": "loh"}
)
# Pass the held-out loader as the second loader so the per-epoch evaluation is
# not computed on the training data (the original passed train_loader twice and
# never used test_loader).
# NOTE(review): assumes train()'s 5th positional argument is the validation
# loader and the 6th (100) is the number of epochs -- confirm against
# src.utils.trainer.train.
train(
    model, optimizer, criterion, train_loader,
    test_loader, 100, DEVICE, wdb,
    log_output = False, grad_clipping=10
)



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111512737762597, max=1.0)…

=== Epoch 1 ===


100%|██████████| 4/4 [00:03<00:00,  1.19it/s]
100%|██████████| 4/4 [00:00<00:00, 11.81it/s]


=== Epoch 2 ===


100%|██████████| 4/4 [00:00<00:00, 18.19it/s]
100%|██████████| 4/4 [00:00<00:00, 16.85it/s]


=== Epoch 3 ===


100%|██████████| 4/4 [00:00<00:00, 58.89it/s]
100%|██████████| 4/4 [00:00<00:00, 15.26it/s]


=== Epoch 4 ===


100%|██████████| 4/4 [00:00<00:00, 47.19it/s]
100%|██████████| 4/4 [00:00<00:00, 15.37it/s]


=== Epoch 5 ===


100%|██████████| 4/4 [00:00<00:00, 15.71it/s]
100%|██████████| 4/4 [00:00<00:00, 13.82it/s]


=== Epoch 6 ===


100%|██████████| 4/4 [00:00<00:00, 57.93it/s]
100%|██████████| 4/4 [00:00<00:00, 13.67it/s]


=== Epoch 7 ===


100%|██████████| 4/4 [00:00<00:00, 24.06it/s]
100%|██████████| 4/4 [00:00<00:00, 14.19it/s]


=== Epoch 8 ===


100%|██████████| 4/4 [00:00<00:00, 18.45it/s]
100%|██████████| 4/4 [00:00<00:00, 12.68it/s]


=== Epoch 9 ===


100%|██████████| 4/4 [00:00<00:00, 46.58it/s]
100%|██████████| 4/4 [00:00<00:00, 14.13it/s]


=== Epoch 10 ===


100%|██████████| 4/4 [00:00<00:00, 45.48it/s]
100%|██████████| 4/4 [00:00<00:00, 19.25it/s]


=== Epoch 11 ===


100%|██████████| 4/4 [00:00<00:00, 17.54it/s]
100%|██████████| 4/4 [00:00<00:00, 17.44it/s]


=== Epoch 12 ===


100%|██████████| 4/4 [00:00<00:00, 69.93it/s]
100%|██████████| 4/4 [00:00<00:00, 19.11it/s]


=== Epoch 13 ===


100%|██████████| 4/4 [00:00<00:00, 20.20it/s]
100%|██████████| 4/4 [00:00<00:00, 16.58it/s]


=== Epoch 14 ===


100%|██████████| 4/4 [00:00<00:00, 62.03it/s]
100%|██████████| 4/4 [00:00<00:00, 15.07it/s]


=== Epoch 15 ===


100%|██████████| 4/4 [00:00<00:00, 24.17it/s]
100%|██████████| 4/4 [00:00<00:00, 12.89it/s]


=== Epoch 16 ===


100%|██████████| 4/4 [00:00<00:00, 64.76it/s]
100%|██████████| 4/4 [00:00<00:00, 11.40it/s]


=== Epoch 17 ===


100%|██████████| 4/4 [00:00<00:00, 56.39it/s]
100%|██████████| 4/4 [00:00<00:00, 17.09it/s]


=== Epoch 18 ===


100%|██████████| 4/4 [00:00<00:00, 14.59it/s]
100%|██████████| 4/4 [00:00<00:00, 15.19it/s]


=== Epoch 19 ===


100%|██████████| 4/4 [00:00<00:00, 45.01it/s]
100%|██████████| 4/4 [00:00<00:00, 19.20it/s]


=== Epoch 20 ===


100%|██████████| 4/4 [00:00<00:00, 46.94it/s]
100%|██████████| 4/4 [00:00<00:00, 16.58it/s]


=== Epoch 21 ===


100%|██████████| 4/4 [00:00<00:00, 25.10it/s]
100%|██████████| 4/4 [00:00<00:00, 14.93it/s]


=== Epoch 22 ===


100%|██████████| 4/4 [00:00<00:00, 55.37it/s]
100%|██████████| 4/4 [00:00<00:00, 19.99it/s]


=== Epoch 23 ===


100%|██████████| 4/4 [00:00<00:00, 39.91it/s]
100%|██████████| 4/4 [00:00<00:00, 12.42it/s]


=== Epoch 24 ===


100%|██████████| 4/4 [00:00<00:00, 16.45it/s]
100%|██████████| 4/4 [00:00<00:00, 15.10it/s]


=== Epoch 25 ===


100%|██████████| 4/4 [00:00<00:00, 60.24it/s]
100%|██████████| 4/4 [00:00<00:00, 15.09it/s]


=== Epoch 26 ===


100%|██████████| 4/4 [00:00<00:00, 20.82it/s]
100%|██████████| 4/4 [00:00<00:00, 15.77it/s]


=== Epoch 27 ===


100%|██████████| 4/4 [00:00<00:00, 53.25it/s]
100%|██████████| 4/4 [00:00<00:00, 15.03it/s]


=== Epoch 28 ===


100%|██████████| 4/4 [00:00<00:00, 20.97it/s]
100%|██████████| 4/4 [00:00<00:00, 14.09it/s]


=== Epoch 29 ===


100%|██████████| 4/4 [00:00<00:00, 49.59it/s]
100%|██████████| 4/4 [00:00<00:00, 14.98it/s]


=== Epoch 30 ===


100%|██████████| 4/4 [00:00<00:00, 17.45it/s]
100%|██████████| 4/4 [00:00<00:00, 14.16it/s]


=== Epoch 31 ===


100%|██████████| 4/4 [00:00<00:00, 69.82it/s]
100%|██████████| 4/4 [00:00<00:00, 16.99it/s]


=== Epoch 32 ===


100%|██████████| 4/4 [00:00<00:00, 52.32it/s]
100%|██████████| 4/4 [00:00<00:00, 12.73it/s]


=== Epoch 33 ===


100%|██████████| 4/4 [00:00<00:00, 26.10it/s]
100%|██████████| 4/4 [00:00<00:00, 17.33it/s]


=== Epoch 34 ===


100%|██████████| 4/4 [00:00<00:00, 60.92it/s]
100%|██████████| 4/4 [00:00<00:00, 16.49it/s]


=== Epoch 35 ===


100%|██████████| 4/4 [00:00<00:00, 16.45it/s]
100%|██████████| 4/4 [00:00<00:00, 17.73it/s]


=== Epoch 36 ===


100%|██████████| 4/4 [00:00<00:00, 80.33it/s]
100%|██████████| 4/4 [00:00<00:00, 17.59it/s]


=== Epoch 37 ===


100%|██████████| 4/4 [00:00<00:00, 47.64it/s]
100%|██████████| 4/4 [00:00<00:00, 16.51it/s]


=== Epoch 38 ===


100%|██████████| 4/4 [00:00<00:00, 17.14it/s]
100%|██████████| 4/4 [00:00<00:00, 16.47it/s]


=== Epoch 39 ===


100%|██████████| 4/4 [00:00<00:00, 75.73it/s]
100%|██████████| 4/4 [00:00<00:00, 16.86it/s]


=== Epoch 40 ===


100%|██████████| 4/4 [00:00<00:00, 21.38it/s]
100%|██████████| 4/4 [00:00<00:00, 16.75it/s]


=== Epoch 41 ===


100%|██████████| 4/4 [00:00<00:00, 64.60it/s]
100%|██████████| 4/4 [00:00<00:00, 15.61it/s]


=== Epoch 42 ===


100%|██████████| 4/4 [00:00<00:00, 48.83it/s]
100%|██████████| 4/4 [00:00<00:00, 14.06it/s]


=== Epoch 43 ===


100%|██████████| 4/4 [00:00<00:00, 17.01it/s]
100%|██████████| 4/4 [00:00<00:00, 14.68it/s]


=== Epoch 44 ===


100%|██████████| 4/4 [00:00<00:00, 55.90it/s]
100%|██████████| 4/4 [00:00<00:00, 17.45it/s]


=== Epoch 45 ===


100%|██████████| 4/4 [00:00<00:00, 17.28it/s]
100%|██████████| 4/4 [00:00<00:00, 13.78it/s]


=== Epoch 46 ===


100%|██████████| 4/4 [00:00<00:00, 66.38it/s]
100%|██████████| 4/4 [00:00<00:00, 14.27it/s]


=== Epoch 47 ===


100%|██████████| 4/4 [00:00<00:00, 45.65it/s]
100%|██████████| 4/4 [00:00<00:00, 13.84it/s]


=== Epoch 48 ===


100%|██████████| 4/4 [00:00<00:00, 16.87it/s]
100%|██████████| 4/4 [00:00<00:00, 12.15it/s]


=== Epoch 49 ===


100%|██████████| 4/4 [00:00<00:00, 53.69it/s]
100%|██████████| 4/4 [00:00<00:00, 14.02it/s]


KeyboardInterrupt: 

Итоговый запуск (run) из wandb: