<a href="https://colab.research.google.com/github/kryuchkovdm/Distillation/blob/master/methods/Student_dist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def train_epoch(train_iter, model, optim, epoch_num, distil=False, alpha=0.5, temperature=10):
    """Run one training epoch, optionally with knowledge distillation.

    Every tensor of each batch is moved to the global ``device``.
    ``batch[0]`` is passed to the model and ``batch[3]`` holds the hard class
    labels. When ``distil`` is true, ``batch[4]`` holds teacher logits laid
    out like the student's output (classes on axis 1).

    With ``distil=True`` the back-propagated loss is::

        (1 - alpha) * CE(student, labels)
            + alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))

    where ``T = temperature``. The ``T**2`` factor (Hinton et al., 2015)
    keeps the soft-target gradients on the same scale as the hard-target
    ones, whose magnitude otherwise shrinks as ``1 / T**2``.

    Args:
        train_iter: iterable of batches (tuples of tensors). It must support
            ``len()`` and expose ``.dataset`` (e.g. a ``DataLoader``).
        model: module returning class logits with the class axis at dim 1.
        optim: optimizer, presumably over ``model``'s parameters.
        epoch_num: epoch index, used only in the progress bar and NaN message.
        distil: if true, mix the distillation term into the loss.
        alpha: weight of the distillation term (presumably in ``[0, 1]``).
        temperature: softmax temperature for student and teacher logits.

    Returns:
        Tuple ``(mean_loss_per_batch, accuracy, macro_f1)``. Accuracy is
        computed over ``len(train_iter.dataset)`` samples.
    """
    train_loss = 0
    train_acc = 0
    y_true = []
    y_pred = []

    model.to(device)
    model.train()

    cost = nn.CrossEntropyLoss()
    # 'batchmean' is the mathematically correct KL divergence. The default
    # 'mean' additionally averages over the class dimension.
    kldloss = nn.KLDivLoss(reduction="batchmean") if distil else None

    for batch in tqdm(train_iter, total=len(train_iter), desc=f"Batch progress for epoch {epoch_num}"):
        batch = tuple(t.to(device) for t in batch)
        input_ids = batch[0]
        labels = batch[3]

        optim.zero_grad()

        student_logits = model(input_ids)
        batch_loss = cost(student_logits, labels)

        if distil:
            # Teacher logits are fixed targets; never back-propagate into them.
            teacher_logits = batch[4].detach()
            soft_predictions = F.log_softmax(student_logits / temperature, dim=1)
            soft_targets = F.softmax(teacher_logits / temperature, dim=1)
            distillation_loss = kldloss(soft_predictions, soft_targets) * temperature ** 2
            # The combined loss is what gets optimized, so the teacher signal
            # actually reaches the student's gradients.
            loss = (1 - alpha) * batch_loss + alpha * distillation_loss
        else:
            loss = batch_loss

        if torch.isnan(loss):
            print("NAN batch loss!", epoch_num, loss, student_logits, labels)

        train_loss += loss.item()

        predictions = student_logits.argmax(1)
        train_acc += (predictions == labels).sum().item()
        y_true.extend(labels.tolist())
        y_pred.extend(predictions.tolist())

        loss.backward()
        optim.step()

    return train_loss / len(train_iter), train_acc / len(train_iter.dataset), f1_score(y_true, y_pred, average="macro")

In [None]:
def validate(test_iter, model):
    """Evaluate ``model`` on ``test_iter`` without tracking gradients.

    Every tensor of each batch is moved to the global ``device``.
    ``batch[0]`` is fed to the model and ``batch[3]`` holds the class labels.
    The model is switched to eval mode and left there.

    Returns:
        Tuple ``(mean_loss_per_batch, accuracy, macro_f1)``:

        - the loss is cross-entropy averaged over batches;
        - accuracy is computed over ``len(test_iter.dataset)`` samples.
    """
    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    n_correct = 0
    gold, predicted = [], []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for raw_batch in tqdm(test_iter, desc="Validating"):
            tensors = tuple([t.to(device) for t in raw_batch])
            token_ids, targets = tensors[0], tensors[3]

            logits = model(token_ids)
            running_loss += criterion(logits, targets).item()

            guesses = logits.argmax(1)
            n_correct += (guesses == targets).sum().item()
            gold.extend(targets.tolist())
            predicted.extend(guesses.tolist())

    return (
        running_loss / len(test_iter),
        n_correct / len(test_iter.dataset),
        f1_score(gold, predicted, average="macro"),
    )

In [None]:
def train_loop(model, optim, train_loader, test_loader, n_epochs=5, sched=None, distil=False, alpha=0.5, temperature=10):
    """Train for ``n_epochs`` epochs, validating after each, and collect metrics.

    Each epoch calls ``train_epoch`` (forwarding ``distil``, ``alpha`` and
    ``temperature``), then ``sched.step()`` if a scheduler is given, then
    ``validate`` on ``test_loader``.

    A ``KeyboardInterrupt`` (Ctrl-C) stops training early. The metrics of
    every fully completed epoch are still returned.

    Returns:
        pandas DataFrame with one row per completed epoch and the columns
        ``epoch, train_loss, train_acc, train_f1_macro, test_loss, test_acc,
        test_f1_macro``.
    """
    columns = ["epoch", "train_loss", "train_acc", "train_f1_macro",
               "test_loss", "test_acc", "test_f1_macro"]
    rows = []

    model.to(device)

    try:
        for epoch in range(n_epochs):
            train_loss, train_acc, train_f1 = train_epoch(
                train_loader, model, optim, epoch_num=epoch,
                distil=distil, alpha=alpha, temperature=temperature)
            if sched is not None:
                sched.step()
            test_loss, test_acc, test_f1 = validate(test_loader, model)
            # Record a whole row in one append so an interrupt can never leave
            # columns of different lengths. A pre-filled "epoch" column would
            # make the DataFrame constructor raise after an early stop.
            rows.append((epoch, train_loss, train_acc, train_f1,
                         test_loss, test_acc, test_f1))
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early but keeps finished epochs.
        pass

    return pd.DataFrame(rows, columns=columns)