In [1]:
#%% For Colab only: mount Google Drive so the dataset/model paths under /content/gdrive resolve.
from google.colab import drive
drive.mount('/content/gdrive')
#%%

Mounted at /content/gdrive


In [1]:
# !pip install datasets
# !pip install transformers
# !pip install torchcontrib
import torch
from torch import nn
import torch.nn.functional as F
from torch import BoolTensor, FloatTensor, LongTensor
from typing import Optional
from datasets import load_dataset
import transformers
from transformers import DistilBertTokenizerFast, BertPreTrainedModel, get_linear_schedule_with_warmup, AdamW
from torchcontrib.optim import SWA
import regex as re
import numpy as np
from tqdm import tqdm

In [2]:
# Punctuation vocabulary: the marks sorted by code point get ids 1..N,
# a plain space (no punctuation) is id 0, and the empty string maps to -100.
tags = sorted('.?!,;:-—…')
tag2id = dict(zip(tags, range(1, len(tags) + 1)))
tag2id.update({' ': 0, '': -100})
# Reverse lookup, used to turn label ids back into text.
id2tag = dict(zip(tag2id.values(), tag2id.keys()))
# Module-level DistilBERT fast tokenizer; PunctuationDataset.view uses it to map token ids back to tokens.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
class PunctuationDataset(torch.utils.data.Dataset):
    """Map-style dataset of tokenized windows with optional per-token punctuation labels.

    Every item is a dict of tensors keyed by ``input_ids``, ``attention_mask`` and,
    when labels were supplied, ``labels``.
    """

    def __init__(self, input_ids:LongTensor, attention_mask:FloatTensor, labels:Optional[LongTensor] = None) -> None:
        """
        :param input_ids: token ids, indexable by example
        :param attention_mask: attention mask aligned with input_ids, null->0
        :param labels: true labels aligned with input_ids, optional (omit for unlabeled data)
        :return None
        """
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __getitem__(self, idx):
        """
        :param idx: index of the example
        :return dict with 'input_ids' (long) and 'attention_mask' (float32) tensors,
                plus 'labels' (long) only when the dataset was built with labels
        """
        item = {'input_ids': torch.as_tensor(self.input_ids[idx],dtype=torch.long),
                'attention_mask': torch.as_tensor(self.attention_mask[idx],dtype=torch.float32)}
        # labels is Optional: indexing None would raise, so the key is simply left out.
        if self.labels is not None:
            item['labels'] = torch.as_tensor(self.labels[idx],dtype=torch.long)
        return item

    def view(self,idx:int)->str:
        """
        :param idx(int): index of the example
        :return the example as readable text: each token immediately followed by the
                punctuation its label stands for, tokens separated by spaces; without
                labels only the tokens are returned

        Relies on the module-level ``tokenizer`` and ``id2tag``.
        """
        tokens = tokenizer.convert_ids_to_tokens(self.input_ids[idx])
        if self.labels is None:
            return ' '.join(tokens)
        marks = [id2tag[label] for label in self.labels[idx].tolist()]
        return ' '.join(token + mark for token, mark in zip(tokens, marks))

    def __len__(self)->int:
        """:return number of examples (counted on input_ids so unlabeled datasets work too)"""
        return len(self.input_ids)

In [6]:
class config:
  """Hyper-parameters and paths for the punctuation model, held as plain instance attributes.

  The class is instantiated once on the last line of this cell and its ``__dict__`` is
  converted into a ``transformers`` ``PretrainedConfig``; the name ``config`` is then
  rebound to that object, so the class itself is shadowed afterwards.
  """
  def __init__(self):
    # Sequence windowing (presumably window length and overlap in tokens — TODO confirm against preprocessing).
    self.max_len=128
    self.overlap = 126
    # DataLoader batch sizes.
    self.train_batch_size = 4
    self.dev_batch_size = 4
    # Device string passed to torch.device when CUDA is available.
    self.gpu_device = 'cpu' # 'cuda:0' #
    # Two training phases: encoder frozen, then (presumably) partially unfrozen — TODO confirm how unfreeze_layers is used.
    self.freeze_epochs = 20
    self.freeze_lr = 1e-4
    self.unfreeze_epochs = 20
    self.unfreeze_layers = 6
    self.unfreeze_lr = 1e-5
    # Pretrained checkpoint name and Google Drive locations of the preprocessed datasets.
    self.base_model_path = 'distilbert-base-uncased'
    self.train_dataset = '/content/gdrive/MyDrive/ASR/ted_talks_processed.train.pt'
    self.dev_dataset = '/content/gdrive/MyDrive/ASR/ted_talks_processed.dev.pt'
    # Dice-loss exponent (see DiceLoss.alpha) and dropout probability for the classification head.
    self.alpha = 0.8
    self.hidden_dropout_prob = 0.3
    # Model dimensions; num_labels = 9 punctuation tags + the "no punctuation" class.
    self.embedding_dim = 768
    self.num_labels = 10
    self.hidden_dim = 128
    # Loss / head switches.
    self.self_adjusting = True
    self.square_denominator = False
    self.use_crf = False
    # Name and output directory for saved models.
    self.model_name = 'bertcrf'
    self.model_path = "/content/gdrive/MyDrive/ASR/logs/models/"
# Rebinds `config` from the class above to a PretrainedConfig instance carrying the same attributes.
config = transformers.configuration_utils.PretrainedConfig.from_dict(config().__dict__)

In [5]:
# Use config.gpu_device only when CUDA is available; otherwise fall back to CPU.
device = torch.device(config.gpu_device) if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# The loaded object is sliced with [:] and unpacked as keyword arguments, so it presumably
# yields a dict with input_ids / attention_mask / labels — TODO confirm against the preprocessing code.
train_dataset=PunctuationDataset(**torch.load(config.train_dataset,map_location=device)[:])
# dev_dataset=PunctuationDataset(**torch.load(config.dev_dataset,map_location=device)[:])
# Sanity check: render one example (tokens followed by their punctuation tags) as text.
train_dataset.view(-1000)

"[CLS]  turns  activism  into  terrorism  if  it  causes  a  loss  of  profits. now  most  people  never  even  heard  about  this  law, including  members  of  congress. less  than  one  percent  were  in  the  room  when  it  passed  the  house. the  rest  were  outside  at  a  new  memorial. they  were  praising  dr. king  as  his  style  of  activism  was  branded  as  terrorism  if  done  in  the  name  of  animals  or  the  environment. supporters  say  laws  like  this  are  needed  for  the  ex  ##tre  ##mist  ##s: the  van  ##dal  ##s, the  arson  ##ists, the  radicals. but  right  now, companies  like  trans  ##cana  ##da  are  briefing  police  in  presentations  like  this  one  about  how  to  prose  ##cute  non  ##vio  ##lent  protesters  as  terrorists. the  fbi  '  s  training  documents  on  eco- [SEP] "

In [7]:
# NOTE(review): shuffle is not passed, so batches follow dataset order — confirm this is intended for training.
train_dataloader=torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, num_workers=4)
# dev_dataloader=torch.utils.data.DataLoader(dev_dataset, batch_size=config.dev_batch_size, num_workers=2)
# Peek at one batch: maps each field name to the shape of its batched tensor.
{x:y.shape for x,y in next(iter(train_dataloader)).items()} #(batch_size, seq_len)

{'attention_mask': torch.Size([4, 128]),
 'input_ids': torch.Size([4, 128]),
 'labels': torch.Size([4, 128])}

In [32]:
# Shape legend: b_s = batch size (4), s_l = sequence length (128), h_d = hidden dim (768)
# bert=transformers.BertModel.from_pretrained(config.base_model_path)
# input_ids, attention_mask, labels = next(iter(train_dataloader)).values() # (batch_size, seq_len) * 3
# bo=bert(input_ids, attention_mask) # => last_hidden_state (b_s, s_l, h_d 768), pooler_output (b_s, h_d)
# dropout = nn.Dropout(config.hidden_dropout_prob) # (b_s, s_l, h_d)
# sequence_output=dropout(bo[0])
# fcl = nn.Linear(config.embedding_dim, config.num_labels)
# fcl_output=fcl(sequence_output) #(b_s, s_l, h_d) -> (b_s, s_l, num_labels)

# bert(next)

In [151]:
# NOTE(review): relies on hidden kernel state — `fcl_output` and `labels` are only assigned in
# commented-out scratch code, and DiceLoss is defined in a later cell; fails on Restart & Run All.
DiceLoss()(fcl_output,labels)

tensor(0.9575, grad_fn=<MeanBackward0>)

In [148]:
# dice_loss(fcl_output,labels,attention_mask)
#b_s,s_l,n_l
# pred_soft=torch.softmax(fcl_output,-1)#,torch.softmax(fcl_output[0,0,:],-1) #apply softmax to each token
# target_one_hot=F.one_hot(labels,num_classes=config.num_labels)#,labels[0,:5] (b_s, s_l) -> (b_s, s_l, n_l)
# pred_factor=((1-pred_soft) ** config.alpha) if config.self_adjusting else 1
# pred_prod=pred_factor*pred_soft*target_one_hot
# sum(pred_prod,0).shape
# smooth = 1e-8
# intersection=torch.sum(pred_prod,1)
# NOTE(review): `intersection` and `pred_prod` are only assigned in commented-out lines of this cell, so this relies on leftover kernel state.
intersection.shape,pred_prod.shape
# cardinality =torch.sum(pred_factor*pred_soft + target_one_hot, 1)
# dice_score=1-2*(intersection+smooth)/(cardinality+smooth)
# dice_score[0,:]
# weight=[0,0,0,0,0,0,0,1,0,0]
# (dice_score[:2,:]*torch.tensor(torch.tensor(weight))).shape
### torch.gather(pred_soft[0,:5],-1,index=labels[0,:5].unsqueeze(-1)) #returns the probability for the most likely example is this really needed?
# target_one_hot.shape
# labels.unsqueeze(2).shape,labels.shape
# pred_soft[0,:5].shape

(torch.Size([4, 10]), torch.Size([4, 128, 10]))

In [149]:
class DiceLoss(nn.Module):
    r"""
    Creates a criterion that optimizes a multi-class Self-adjusting Dice Loss
    ("Dice Loss for Data-imbalanced NLP Tasks" paper)

    Soft Dice terms are accumulated by summing over dimension 1 of the input, so for
    ``(batch_size, seq_len, num_labels)`` logits there is one loss value per
    ``(batch, class)`` pair before reduction.

    Args:
        smooth (float): a factor added to both the nominator and the denominator for smoothing purposes
        self_adjusting (bool): if True, every probability is scaled by ``(1 - p) ** alpha``
        reduction (str): ``"mean"``, ``"sum"``, or ``"none"``/``None``
        alpha (float): a factor to push down the weight of easy examples
        weight (number or sequence): scalar or per-class weight multiplied into the per-class loss
        ignore_index (int): target value marking positions that must not contribute to the loss
    """
    def __init__(self,
                 smooth: Optional[float] = 1e-8,
                #  square_denominator: Optional[bool] = False,
                 self_adjusting: Optional[bool] = False,
                #  with_logits: Optional[bool] = True,
                 reduction: Optional[str] = "mean",
                 alpha: float = 1.0,
                 weight=1, #int or list
                 ignore_index: int = -100,
                 ) -> None:
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.self_adjusting = self_adjusting
        self.alpha = alpha
        self.smooth = smooth
        # self.square_denominator = square_denominator
        self.weight=weight

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                # mask: Optional[torch.Tensor] = None,
                num_classes: int = 10,
                ) -> torch.Tensor:
        """
        :param pred: raw logits, classes on the last dimension
        :param target: integer class ids with the shape of ``pred`` minus its last dimension;
                       entries equal to ``ignore_index`` are excluded
        :param num_classes: number of classes used for the one-hot encoding of ``target``
        :return the reduced loss, or the per-(batch, class) loss when reduction is "none"/None
        :raises NotImplementedError: for an unknown ``reduction``
        """
        pred_soft = torch.softmax(pred,-1) #(batch_size,seq_len,num_labels)->(batch_size,seq_len,num_labels), sum along num_labels to 1

        # F.one_hot rejects negative class ids, so ignored positions are temporarily
        # mapped to class 0 and then zeroed out of both prediction and target.
        keep = (target != self.ignore_index)           # (b_s, s_l)
        keep_col = keep.unsqueeze(-1)                  # (b_s, s_l, 1)
        safe_target = target.masked_fill(~keep, 0)
        target_one_hot = F.one_hot(safe_target,num_classes=num_classes) * keep_col #(b_s, s_l) -> (b_s, s_l, n_l)
        pred_soft = pred_soft * keep_col

        pred_factor = ((1-pred_soft) ** self.alpha) if self.self_adjusting else 1
        intersection = torch.sum(pred_factor * pred_soft * target_one_hot, 1) # (b_s,s_l,n_l)->(b_s,n_l)
        cardinality = torch.sum(pred_factor * pred_soft + target_one_hot, 1)  # (b_s,s_l,n_l)->(b_s,n_l)

        # Build the weight on pred's device/dtype, and apply it to the loss (1 - dice):
        # without the parentheses it would scale the dice coefficient instead, so a
        # weight of 0 would produce the maximum loss rather than switching a class off.
        class_weight = torch.as_tensor(self.weight, dtype=pred_soft.dtype, device=pred_soft.device)
        dice_loss = (1. - 2. * (intersection + self.smooth) / (cardinality + self.smooth)) * class_weight

        if self.reduction == "mean":
            return dice_loss.mean()
        elif self.reduction == "sum":
            return dice_loss.sum()
        elif self.reduction == "none" or self.reduction is None:
            return dice_loss
        else:
            raise NotImplementedError(f"Reduction `{self.reduction}` is not supported.")

    def __str__(self):
        return f"Dice Loss smooth:{self.smooth}"

# def dice_loss(output, target, mask, weight=1):
#     lfn = DiceLoss(square_denominator=config.square_denominator,self_adjusting=config.self_adjusting,alpha=config.alpha, weight=weight)
#     active_loss = mask.view(-1) == 1 # (batch_size,seq_len)->(batch_size*seq_len)
#     active_logits = output.view(-1, config.num_labels) # (batch_size, seq_len ,num_labels) -> (batch_size*seq_len,num_labels)
#     active_labels = torch.where(
#         active_loss,
#         target.view(-1),
#         torch.tensor(-100).type_as(target)
#     ) # -100 if out of mask else target (batch_size*seq_len)
# #     assert 0, f'{target.size()},{active_labels.size()}'
#     loss = lfn(active_logits, active_labels,num_classes=config.num_labels)
#     print('loss',loss)
#     return loss

In [None]:
### engine.py
# from tqdm import tqdm
def train_step(trainer,batch):
    """Run one optimization step on a single batch and return its loss as a float.

    The ``(trainer, batch)`` signature is what ``Engine(train_step)`` is given
    elsewhere in this notebook; ``trainer`` itself is not used. The step operates on the
    module-level ``model``, ``optimizer``, ``scheduler`` and ``device``.

    :param trainer: the engine invoking this step (unused)
    :param batch: a dict of tensors (what PunctuationDataset + default collation yields)
                  or any other iterable of tensors
    :return: the batch loss as a Python float
    """
    model.train()
    optimizer.zero_grad()
    # Iterating a dict yields its string keys, so dict batches must be moved value by value;
    # the container type is preserved so the model receives what the DataLoader produced.
    if isinstance(batch, dict):
        batch = {key: value.to(device) for key, value in batch.items()}
    else:
        batch = [_data.to(device) for _data in batch]
    _, loss = model(batch)
    loss.backward()
    optimizer.step()
    scheduler.step()  # the LR schedule advances once per batch
    return loss.item()

def train_fn(data_loader, model, optimizer, device, scheduler):
    """Train ``model`` for one pass over ``data_loader`` and return the mean batch loss.

    :param data_loader: sized iterable of batches; each batch is a dict of tensors
                        (as produced by PunctuationDataset) or another iterable of tensors
    :param model: module whose call returns ``(output, loss)`` for a batch
    :param optimizer: optimizer stepped once per batch
    :param device: device batches are moved to when CUDA is available
    :param scheduler: LR scheduler stepped once per batch
    :return: sum of batch losses divided by ``len(data_loader)``
    """
    model.train()
    final_loss = 0
    for data in tqdm(data_loader, total=len(data_loader)):
        if torch.cuda.is_available():
            # Iterating a dict yields its string keys, so dict batches are moved value by value.
            if isinstance(data, dict):
                data = {key: value.to(device) for key, value in data.items()}
            else:
                data = [_data.to(device) for _data in data]
        optimizer.zero_grad()
        _, loss = model(data)
        loss.backward()
        optimizer.step()
        scheduler.step()
        final_loss += loss.item()
    return final_loss / len(data_loader)


def eval_fn(data_loader, model, device):
    """Evaluate ``model`` over ``data_loader`` and return the mean batch loss.

    :param data_loader: sized iterable of batches; each batch is a dict of tensors
                        (as produced by PunctuationDataset) or another iterable of tensors
    :param model: module whose call returns ``(output, loss)`` for a batch
    :param device: device batches are moved to when CUDA is available
    :return: sum of batch losses divided by ``len(data_loader)``
    """
    model.eval()
    final_loss = 0
    # No gradients are needed for evaluation; disabling them avoids building autograd graphs.
    with torch.no_grad():
        for data in tqdm(data_loader, total=len(data_loader)):
            if torch.cuda.is_available():
                # Iterating a dict yields its string keys, so dict batches are moved value by value.
                if isinstance(data, dict):
                    data = {key: value.to(device) for key, value in data.items()}
                else:
                    data = [_data.to(device) for _data in data]
            _, loss = model(data)
            final_loss += loss.item()
    return final_loss / len(data_loader)


In [None]:
# --- Phase 1: train the head on top of a frozen encoder ----------------------
# NOTE(review): BertCRFModel is not defined in the visible cells, and Engine, Events,
# Precision, Recall, Fbeta and create_supervised_evaluator (presumably pytorch-ignite)
# are never imported in the import cell — TODO add the definition/imports, otherwise
# this cell cannot survive Restart & Run All.
# Config attributes are lower-case (see the config cell); the upper-case spellings do not exist.
model = BertCRFModel(num_punct=config.num_labels, embedding_dim=config.embedding_dim, hidden_dim=config.hidden_dim, use_crf=config.use_crf)
for param in model.bert.parameters():
    param.requires_grad = False
model.to(device)

# Only parameters that still require gradients (everything outside the frozen encoder) are optimized.
optimizer_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = AdamW(optimizer_parameters, lr=config.freeze_lr)
# train_step advances the scheduler once per batch, so the schedule length is
# batches-per-epoch * epochs; len(train_dataloader) is an int and does not depend on
# the dataset exposing a `.tensors` attribute (PunctuationDataset has none).
num_train_steps = len(train_dataloader) * config.freeze_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps
)
# optimizer = SWA(base_opt)

trainer = Engine(train_step)
val_metrics = {
    "precision": Precision(),
    "recall": Recall(),
#     "Dice": DiceCoefficient(cm=),
    "F1": Fbeta(1),
}
# NOTE(review): the recorded run died with "too many values to unpack (expected 2)".
# The dataloader yields dicts with three keys, which presumably does not match the
# (x, y) batches the supervised evaluator unpacks — TODO confirm and supply a matching
# prepare_batch / output_transform.
evaluator = create_supervised_evaluator(model, metrics=val_metrics)
def log_metrics(engine, title):
    """Print every metric the evaluator computed, tagged with the trainer's current epoch."""
    # The evaluator only has precision / recall / F1 attached; there is no "acc" entry.
    summary = ", ".join("{}: {}".format(name, value) for name, value in engine.state.metrics.items())
    print("Epoch: {} - {} {}".format(trainer.state.epoch, title, summary))

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    """After each training epoch, evaluate on the train and dev loaders and log the metrics."""
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
        evaluator.run(train_dataloader)

    # NOTE(review): dev_dataloader is only created in a commented-out line of the
    # dataloader cell — TODO enable it before running this handler.
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "dev"):
        evaluator.run(dev_dataloader)

# Run exactly as many epochs as the LR schedule covers; past num_train_steps the
# linear schedule has decayed the learning rate to zero.
trainer.run(train_dataloader, max_epochs=config.freeze_epochs)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing BertModel: ['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'd

Current run is terminating due to exception: too many values to unpack (expected 2).
Engine run is terminating due to exception: too many values to unpack (expected 2).
Engine run is terminating due to exception: too many values to unpack (expected 2).


ValueError: too many values to unpack (expected 2)

In [None]:
# IPython-only introspection (`??` displays the object's source); leftover from debugging — remove before sharing.
torch.utils.data.DataLoader??