In [None]:
class Configs:
    """Experiment configuration: file locations, loader settings and model size.

    Every option has a default. Any of them can be overridden by keyword
    without editing this class, e.g. ``Configs(batch_size=32, num_layers=6)``.

    Raises:
        TypeError: if an override names an option that does not exist
            (guards against silently ignored typos such as ``batchsize=``).
    """

    def __init__(self, **overrides):
        # --- data / loading -------------------------------------------------
        # NOTE(review): of these, only `labels_path` is read in this notebook
        # (by Tokenizer); the rest are presumably consumed by DataModule,
        # which receives the whole config object — confirm there.
        self.manifest_file = "total_am.txt"
        self.labels_path = "labels.csv"
        self.train_ratio = 0.8
        self.num_workers = 4
        self.batch_size = 64
        self.sample_mode = 'random'  # alternative noted by the original author: 'smart'
        self.teacher_forcing_ratio = 0.0

        # --- model hyperparameters (all passed to Transformer_LM) -----------
        self.num_classes = 2001
        self.d_model = 512
        self.d_ff = 2048
        self.num_heads = 4
        self.num_layers = 3
        self.model_name = "BERT"

        # Apply overrides last so they win over the defaults, but only for
        # options that actually exist.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown config option(s): {', '.join(unknown)}")
        vars(self).update(overrides)

    def __repr__(self):
        # Readable dump of every option, handy when the object is displayed in a cell.
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"


configs = Configs()

In [None]:
from Tokenizer import Tokenizer
from data_module import DataModule


# Build the tokenizer from the labels file named in the config.
tokenizer = Tokenizer(label_file=configs.labels_path)
# DataModule receives the whole config object plus the tokenizer.
# NOTE(review): presumably it reads manifest_file / train_ratio / batch_size /
# num_workers from `configs` to build the splits — confirm in data_module.py.
data_module = DataModule(configs, tokenizer)
# One loader per split name; both are iterated by the training cell.
train_dataloader = data_module.get_dl("train")
valid_dataloader = data_module.get_dl("valid")

In [None]:
from Model import Transformer_LM

# Instantiate the language model with the size hyperparameters from `configs`.
# NOTE(review): the `model=` argument ("BERT" here) presumably selects an
# architecture variant inside Transformer_LM — confirm in Model.py.
model = Transformer_LM(
    num_classes=configs.num_classes,
    d_model=configs.d_model,
    d_ff=configs.d_ff,
    num_heads=configs.num_heads,
    num_layers=configs.num_layers,
    model=configs.model_name
)

In [None]:
# Move the model to the GPU. A CUDA device is required: the training cell
# also calls .cuda() on every input and target batch.
model = model.cuda()

In [None]:
from criterion import CrossEntropyLoss
from torch.optim import Adam

# Project-local loss, constructed with the tokenizer.
# NOTE(review): presumably the tokenizer is used to ignore padding/special
# tokens — confirm in criterion.py.
Loss = CrossEntropyLoss(tokenizer)
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = Adam(model.parameters(), lr=1e-4)

In [None]:
from torch.utils.tensorboard import SummaryWriter
import torch


VALIDATE_EVERY = 1000   # run a validation pass every this many training iterations
MAX_VAL_BATCHES = 101   # validation batches per pass (same count as the original `i > 100` cut-off)


def _run_validation(model, dataloader, criterion, max_batches):
    """Return the mean per-batch loss over at most ``max_batches`` batches.

    The model is switched to eval mode for the pass (so layers such as dropout
    behave deterministically) and restored to its previous mode afterwards,
    even if a batch raises. No gradients are recorded.

    Args:
        model: callable as ``model(inputs, lengths)`` returning ``(logits, preds)``.
        dataloader: iterable of ``(inputs, lengths, targets)`` batches.
        criterion: callable as ``criterion(logits, targets)`` returning a scalar tensor.
        max_batches: upper bound on the number of batches evaluated.

    Returns:
        The mean loss as a Python float, or ``None`` if no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch_idx, (val_inputs, val_lengths, val_targets) in enumerate(dataloader):
                if batch_idx >= max_batches:
                    break
                val_inputs = val_inputs.cuda()
                val_targets = val_targets.cuda()
                val_logits, _ = model(val_inputs, val_lengths)
                # .item() keeps the running sum as a plain float instead of a GPU tensor.
                total_loss += criterion(val_logits, val_targets).item()
                num_batches += 1
    finally:
        # Always hand the model back in the mode it was in before validation.
        model.train(was_training)

    if num_batches == 0:
        return None
    return total_loss / num_batches


writer = SummaryWriter('runs/bert')

# try/finally so the event file is flushed and closed even if training is
# interrupted or a batch raises.
try:
    # Single pass over the training loader; `iteration` doubles as the logging step.
    for iteration, (inputs, seq_lengths, targets) in enumerate(train_dataloader):
        inputs = inputs.cuda()
        targets = targets.cuda()
        optimizer.zero_grad()
        logits, preds = model(inputs, seq_lengths)
        loss = Loss(logits, targets)
        # Perplexity is logged only, so compute it outside the autograd graph.
        perplexity = torch.exp(loss.detach())
        writer.add_scalar("train_loss", loss.item(), iteration)
        writer.add_scalar("train_perplexity", perplexity.item(), iteration)

        loss.backward()
        optimizer.step()

        # Periodic validation; iteration 0 is skipped as in the original schedule.
        if iteration % VALIDATE_EVERY == 0 and iteration != 0:
            validation_loss = _run_validation(model, valid_dataloader, Loss, MAX_VAL_BATCHES)
            # An empty validation loader yields None; skip logging instead of dividing by zero.
            if validation_loss is not None:
                validation_perplexity = torch.exp(torch.tensor(validation_loss)).item()
                writer.add_scalar("validation_loss", validation_loss, iteration)
                writer.add_scalar("validation_perplexity", validation_perplexity, iteration)
finally:
    writer.close()

In [None]:
# Save only the model's state_dict (optimizer state is not saved) to
# "bert.pt", a path relative to the current working directory.
torch.save(model.state_dict(), "bert.pt")