In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from datasets import load_dataset
from tqdm.notebook import tqdm

from navec import Navec

import matplotlib.pyplot as plt
import numpy as np

In [2]:
from extractor.model import Extractor
from extractor.utils import create_target_mask, TaggingDataset

In [3]:
dataset = load_dataset('IlyaGusev/gazeta', revision="v2.0")

No config specified, defaulting to: gazeta/default
Found cached dataset gazeta (/home/goncharovglebig/.cache/huggingface/datasets/IlyaGusev___gazeta/default/2.0.0/c329f0fc1c22ab6e43e0045ee659d0d43c647492baa2a6ab3a5ea7dac98cd552)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
train_dataset = TaggingDataset(dataset['train']['text'], dataset['train']['summary'])
val_dataset = TaggingDataset(dataset['validation']['text'], dataset['validation']['summary'])

  0%|          | 0/60964 [00:00<?, ?it/s]

  0%|          | 0/6369 [00:00<?, ?it/s]

In [5]:
navec = Navec.load('navec_hudlit_v1_12B_500K_300d_100q.tar')
navec_emb_size = 300
vocab = navec.vocab

In [6]:
def collate_batch(batch):
    art_list, target_list = [], []
    for art, target in batch:
        art_encoded = [vocab[tok] if tok in vocab else vocab.unk_id
                       for tok in art]
        art_list.append(torch.tensor(np.array(art_encoded)))
        target_list.append(torch.FloatTensor(np.array(target)))
        
    art_tensor_padded = pad_sequence(art_list, padding_value=vocab.pad_id).T
    target_tensor_padded = pad_sequence(target_list, padding_value=0).T
    return art_tensor_padded.to(device), target_tensor_padded.to(device)

In [7]:
train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=256, collate_fn=collate_batch)

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
model = Extractor()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
min_loss = 99999
model.to(device)

params = list(model.parameters())
optimizer = torch.optim.Adagrad(params, lr=0.1, initial_accumulator_value=0.1)
criternion = nn.BCEWithLogitsLoss()

epoch_num = 45
train_loss_list, val_loss_list = [], []

for ep in tqdm(range(epoch_num)):
    for train_batch in tqdm(train_loader):
        src, target = train_batch
        out = model(src).squeeze()
        
        loss = criternion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss = loss.detach().item()
        train_loss_list.append(train_loss)

    #Validation loop
    val_batch_loss = []
    with torch.no_grad():
        for val_batch in val_loader:
            src, target = train_batch
            out = model(src).squeeze()
            loss = criternion(out, target)
            val_batch_loss.append(loss.detach().item())
    val_loss = np.mean(val_batch_loss)
    val_loss_list.append(val_loss)
    if val_loss < min_loss:
        min_loss = val_loss
        print('Saving best model')
        torch.save(model, f'extractor.pth')

    print(f'For epoch #{ep} train loss {np.mean(train_loss_list[-250_000:])}, val loss {val_loss}')
    

  torch.from_numpy(navec.pq.indexes),


  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #0 train loss 0.5898529271201609, val loss 0.5955839490890503


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #1 train loss 0.554633400929024, val loss 0.5057059812545777


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #2 train loss 0.5214323616759381, val loss 0.47204349279403685


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #3 train loss 0.49987559656468394, val loss 0.4620335936546326


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #4 train loss 0.4853986437101244, val loss 0.45614701628684995


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #5 train loss 0.4749630330081099, val loss 0.4515608847141266


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #6 train loss 0.46697327551625123, val loss 0.4472822380065918


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #7 train loss 0.4605655593945641, val loss 0.4438824439048767


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #8 train loss 0.4552731085023122, val loss 0.44088009238243103


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #9 train loss 0.4508140224293186, val loss 0.4386892640590668


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #10 train loss 0.4469998479682985, val loss 0.4368288171291351


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #11 train loss 0.4436897514089215, val loss 0.43562416076660154


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #12 train loss 0.44078714787326984, val loss 0.43397526144981385


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #13 train loss 0.4382099921893589, val loss 0.432737729549408


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #14 train loss 0.4358986149273158, val loss 0.43190905928611756


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #15 train loss 0.4338081375497031, val loss 0.4308781802654266


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #16 train loss 0.43190699638241825, val loss 0.4296653258800507


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #17 train loss 0.4301665836519998, val loss 0.4289004194736481


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #18 train loss 0.42856388679030605, val loss 0.42791874885559084


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #19 train loss 0.42708275333109263, val loss 0.4269857668876648


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #20 train loss 0.42570571909270977, val loss 0.4259749293327332


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #21 train loss 0.4244240636821481, val loss 0.4253085362911224


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #22 train loss 0.4232281451770039, val loss 0.4243180763721466


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #23 train loss 0.4221083015421694, val loss 0.4233721625804901


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #24 train loss 0.4210553166856327, val loss 0.4221030235290527


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #25 train loss 0.4200642124935865, val loss 0.42080990195274354


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #26 train loss 0.4191191824170096, val loss 0.41888948798179626


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #27 train loss 0.41820443840669574, val loss 0.41566486477851866


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #28 train loss 0.4173072428951668, val loss 0.41433671951293943


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #29 train loss 0.41643998115381936, val loss 0.41378650546073914


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #30 train loss 0.4156080949063532, val loss 0.4123861074447632


  0%|          | 0/239 [00:00<?, ?it/s]

For epoch #31 train loss 0.41481168413433445, val loss 0.4130461835861206


  0%|          | 0/239 [00:00<?, ?it/s]

Saving best model
For epoch #32 train loss 0.4140474852257179, val loss 0.4110026872158051


  0%|          | 0/239 [00:00<?, ?it/s]

For epoch #33 train loss 0.4133167070669697, val loss 0.4113087594509125


  0%|          | 0/239 [00:00<?, ?it/s]

In [None]:
plt.plot(train_loss_list, label='train');
plt.legend();

In [None]:
plt.plot(val_loss_list, label='val');
plt.legend();