In [None]:
import torch
import transformers
from catalyst import dl
from src.runners import DistilMLMRunner
from src.data import MLMDataset
from src.models import DistilbertStudentModel, BertForMLM
import pandas as pd

In [None]:
try:
    train_df = pd.read_csv("data/lenta-ru-news.csv")[:10000]
    valid_df = pd.read_csv("data/lenta-ru-news.csv")[10000:12000]
except:
    ! bin/download_lenta.sh
    train_df = pd.read_csv("data/lenta-ru-news.csv")[:10000]
    valid_df = pd.read_csv("data/lenta-ru-news.csv")[10000:12000]

In [None]:
# Tokenizer/model checkpoint shared by the datasets and both networks.
model_name = "DeepPavlov/rubert-base-cased"
BATCH_SIZE = 2  # per-step batch size used by both loaders

train_dataset = MLMDataset(train_df["text"], model_name=model_name)
valid_dataset = MLMDataset(valid_df["text"], model_name=model_name)

# Shuffle only the training split; a fixed validation order keeps
# evaluation deterministic and comparable across epochs.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=False
)
loaders = {"train": train_dataloader, "valid": valid_dataloader}

In [None]:
# Build the teacher and the student from the same checkpoint name, then bundle
# them into one ModuleDict so the runner receives a single module.
teacher, student = BertForMLM(model_name), DistilbertStudentModel(model_name)
model = torch.nn.ModuleDict(dict(teacher=teacher, student=student))

In [None]:
# Fall back to CPU so the cell still runs (slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
runner = DistilMLMRunner(device=device)

# Distillation trains the student only; the teacher supplies targets and
# must not be updated, so its parameters are kept out of the optimizer.
optimizer = torch.optim.Adam(model["student"].parameters(), lr=5e-5)

runner.train(
    model=torch.nn.DataParallel(model),
    optimizer=optimizer,
    loaders=loaders,
    verbose=True,
    num_epochs=10,
    callbacks={
        "optimizer": dl.OptimizerCallback(
            metric_key="loss",     # name of the batch metric to backpropagate
            accumulation_steps=1,  # set > 1 to accumulate gradients over several batches
        )
    },
)