## Initialization

In [9]:
# Which machine this notebook is running on; selects the working folder below.
storageName = 'university'

# Root working folder for every supported environment.
_STORAGE_ROOTS = {
    'paperspace': "/notebooks/",
    'colab': "/content/",
    'university': '/vol/scratch/guy/',
}

# Fail here, with a readable message, on an unrecognised name; previously an
# unknown value left guy_folder undefined and surfaced as a NameError below.
if storageName not in _STORAGE_ROOTS:
    raise ValueError("Unknown storageName {!r}; expected one of {}".format(
        storageName, sorted(_STORAGE_ROOTS)))
guy_folder = _STORAGE_ROOTS[storageName]

# Directory passed as `cache_dir` to the from_pretrained calls in this notebook.
cache_dir = guy_folder+"/cache/transformer_cache"

In [None]:
%pip install --no-cache-dir datasets

## Imports

In [18]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers import AdamW
from datasets import load_dataset
from functools import partial
from tqdm import tqdm
%matplotlib inline

In [19]:
%run ./utils.ipynb

## Functions

In [22]:
# Checkpoint name handed to both AutoModel and AutoTokenizer in this notebook.
modelName = "distilbert-base-cased"

def _introduceBinLinear(m_, with_fc):
    """Swap selected direct children of `m_` for `BinLinear` wrappers, in place.

    Children named 'k_lin', 'q_lin', 'v_lin' or 'out_lin' are always replaced;
    'lin1' and 'lin2' are replaced as well when `with_fc` is true. Each
    replacement is built from the child's own `weight` and `bias`. Only the
    immediate children of `m_` are inspected, so the function is meant to be
    applied to every module of a model (e.g. through `nn.Module.apply`).
    """
    targets = {'k_lin', 'q_lin', 'v_lin', 'out_lin'}
    if with_fc:
        targets |= {'lin1', 'lin2'}
    # Take a snapshot of the children before reassigning any of them.
    for child_name, child in list(m_.named_children()):
        if child_name not in targets:
            continue
        setattr(m_, child_name, BinLinear(child.weight, child.bias))

In [23]:
class Binarizer(torch.autograd.Function):
    """Hard 0/1 thresholding whose backward pass forwards the gradient unchanged."""

    @staticmethod
    def forward(ctx, input, threshold):
        """Return a float tensor holding 1.0 where `input > threshold` and 0.0 elsewhere."""
        above = input > threshold
        return above.float()

    @staticmethod
    def backward(ctx, grad_output):
        """Hand `grad_output` straight to `input`; `threshold` gets no gradient."""
        input_grad, threshold_grad = grad_output, None
        return input_grad, threshold_grad


class BinLinear(nn.Module):
    """Linear map whose weight is gated element-wise by a binarized mask.

    The effective weight is `originalWeight * Binarizer(rawMask, threshold)`,
    where `rawMask` is a parameter with the same shape as the weight.

    Args:
        originalWeight: weight tensor; its shape defines the shape of `rawMask`.
        originalBias: bias added to the output, or None for no bias.
        initialSparsity: probability that a `rawMask` entry is initialised to 0.
        threshold: per-instance binarization threshold; None means "use the
            class-level `BinLinear.threshold`".
    """
    # Shared default threshold, read by instances created with threshold=None.
    threshold = .5

    def __init__(self, originalWeight, originalBias, initialSparsity = .7, threshold = None):
        super().__init__()
        self.threshold = threshold
        self.originalWeight = originalWeight
        self.originalBias = originalBias

        weight_shape = originalWeight.shape
        # Note: an entry that survives the sparsity draw is still a uniform
        # sample from [0, 1), so it is not guaranteed to exceed the threshold.
        random_scores = torch.rand(*weight_shape)
        keep = self._sparsityMask(initialSparsity, weight_shape)
        self.rawMask = nn.Parameter(random_scores * keep)

    def forward(self, x):
        """Apply the masked weight to `x` and add the bias when there is one."""
        gate = Binarizer.apply(self.rawMask, self.getThreshold())
        gated_weight = self.originalWeight * gate
        output = x @ gated_weight.t()
        if self.originalBias is None:
            return output
        output += self.originalBias
        return output

    def _sparsityMask(self, initialSparsity, shape):
        """Return a random 0/1 float tensor of `shape`; an entry is 1 when its draw exceeds `initialSparsity`."""
        draws = torch.rand(*shape)
        return (draws > initialSparsity).float()

    def extra_repr(self):
        """Describe the layer as '<in>, <out>' plus a bias marker, for `repr(module)`."""
        bias_suffix = ', bias = True' if self.originalBias is not None else ''
        return '{}, {}'.format(*self.rawMask.t().shape) + bias_suffix

    def getThreshold(self):
        """Return the instance threshold, falling back to `BinLinear.threshold` when it is None."""
        own = self.threshold
        return BinLinear.threshold if own is None else own
    

class BinBert(nn.Module):
    """Frozen pretrained encoder with its linear layers swapped for `BinLinear`, plus an optional head.

    Args:
        n_classes: when given, a `BinLinear` head mapping 768 features to
            `n_classes` outputs is attached as `label_fc`.
        with_fc: forwarded to `_introduceBinLinear`; when true the 'lin1' and
            'lin2' layers are wrapped in addition to the attention projections.
        pool: when true, only index 0 along dim 1 of the head output is
            returned. Requires `n_classes`.
    """
    def __init__(self, n_classes = None, 
                 with_fc = True, pool = False):
        super().__init__()
        assert((not pool) or (n_classes is not None))
        self.pool = pool
        self.body = AutoModel.from_pretrained(modelName, cache_dir = cache_dir)
        # Freeze the pretrained tensors first; the rawMask parameters created by
        # the swap below are new Parameters and are not affected by this call.
        self.body.requires_grad_(False)
        self.body.apply(partial(_introduceBinLinear, with_fc = with_fc))
        if n_classes is not None:
            # NOTE(review): 768 is assumed to be the encoder's output width for
            # `modelName` — confirm if the checkpoint changes.
            shape = (768, n_classes)
            self.label_fc = BinLinear(nn.Parameter(torch.rand(*shape).t()),
                                    nn.Parameter(torch.rand(shape[1])))

    def forward(self, x):
        """Run the encoder on `x`, then the optional head and optional pooling."""
        # Index instead of tuple-unpacking: `[0]` selects the first output for
        # both a plain tuple and a transformers ModelOutput, whereas unpacking a
        # ModelOutput iterates over its keys and would yield a string.
        output = self.body(x)[0]
        if hasattr(self, 'label_fc'):
            output = self.label_fc(output)
            if self.pool: 
                output = output[:, 0]
        return output

    def setMaskTrainability(self, requires_grad):
        """Set `requires_grad` on the `rawMask` of every `BinLinear` in this model."""
        def _changeMaskTrainability(m):
            if isinstance(m, BinLinear):
                m.rawMask.requires_grad_(requires_grad)
        self.apply(_changeMaskTrainability)
        
    def iterMasks(self):
        """Iterate over every parameter whose qualified name ends with '.rawMask'."""
        return (param for name, param in self.named_parameters()
                if name.endswith('.rawMask'))
    
    def _randomizeSimpleMask(self, mask, sparsity):
        """Overwrite `mask` in place with random 0/1 values; an entry is 1 when its uniform draw exceeds `sparsity`."""
        with torch.no_grad():
            mask.uniform_()
            # copy_ writes into the existing Parameter, leaving its storage,
            # dtype, device and requires_grad flag alone; set_ would rebind the
            # Parameter to the storage of a freshly created tensor.
            mask.copy_((mask > sparsity).to(mask.dtype))
        
    def randomizeSimpleMasks(self, sparsity):
        """Re-draw every mask returned by `iterMasks` with the given `sparsity`."""
        for mask in self.iterMasks():
            self._randomizeSimpleMask(mask, sparsity)
    
    
    def getLinear(self, i,  linName):
        """Return the first submodule named '...<i>.attention.<linName>' or '...<i>.ffn.<linName>'.

        'lin1' and 'lin2' are looked up under 'ffn', the other names under
        'attention'. Raises AssertionError for an invalid `i` or `linName`, and
        StopIteration when no submodule matches.
        """
        assert(isinstance(i, int) and (linName in ['k_lin', 'q_lin', 'v_lin', 'out_lin', 'lin1', 'lin2']))
        if linName in ['lin1', 'lin2']:
            linName = 'ffn.' + linName
        else:
            linName = 'attention.' + linName
        # The leading "." keeps layer 1 from also matching layer 11, 21, ...
        linName = "." + str(i) + "." + linName
        return next(module for name, module in self.named_modules()
                    if name.endswith(linName))
    
    
    

## Main

In [24]:
# Two output classes for the classification head.
n_classes = 2
# Tokenizer and model are loaded from the same checkpoint and cache directory.
tokenizer = AutoTokenizer.from_pretrained(
    modelName,
    cache_dir = cache_dir,
)
# with_fc=False: only the attention projections are wrapped; pool=True: the
# forward pass returns the head output at index 0 of dim 1.
binBert = BinBert(n_classes = n_classes, with_fc = False, pool = True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [25]:
# Re-draw every rawMask as random 0/1 values (an entry is 1 when its uniform draw exceeds 0.8).
binBert.randomizeSimpleMasks(sparsity = 0.8)

In [15]:
batch_size = 1
# NOTE(review): the recorded run of this cell stopped with a TypeError about
# missing __init__ arguments ('ds', 'tokenizer', 'targetCol'). TextDataset is
# not defined in this notebook (presumably it comes from utils.ipynb) — confirm
# its signature before relying on this call.
train_ds = TextDataset('imdb', 
                       tokenizer = tokenizer, inputCol = 'text', targetCol = 'label')
train_dataloader = DataLoader(train_ds, batch_size = batch_size, 
                              shuffle = True)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(binBert.parameters(), lr = 1e-5)
for epoch in range(10):
    # Per-epoch running sums; both are averaged over the number of batches seen.
    running_loss = 0
    running_acc = 0
    pbar = tqdm(train_dataloader, position = 0, leave = True)
    for i, (x, y) in enumerate(pbar):
        # Gradients accumulate across backward() calls, so clear them each step.
        optimizer.zero_grad()
        # Assumes x is the batch that BinBert.forward passes to the encoder —
        # TODO confirm against what TextDataset yields.
        yhat = binBert(x)
        loss = criterion(yhat, y)
        # CrossEntropyLoss averages over the batch by default, so the value is
        # already per-sample and must not be divided by batch_size again.
        running_loss += loss.item()
        # len(y) rather than batch_size: the last batch of an epoch can be smaller.
        running_acc += (yhat.argmax(-1) == y).sum().item() / len(y)
        pbar.set_postfix_str("Loss: {} Acc: {}".format(running_loss / (1+i), 
                                                       running_acc / (i + 1)
                                                      ))
        loss.backward()
        optimizer.step()

TypeError: __init__() missing 3 required positional arguments: 'ds', 'tokenizer', and 'targetCol'

## Help

In [8]:
# Force a garbage-collection pass; the cell displays the value gc.collect() returns.
import gc
gc.collect()

20

In [26]:
# Print the name and shape of every parameter that currently requires gradients.
# A plain loop replaces the list comprehension that was used only for its
# print() side effects and built a throwaway list of None values.
for param_name, param in binBert.named_parameters():
    if param.requires_grad:
        print("{}: {}".format(param_name, param.shape))

label_fc.originalWeight: torch.Size([2, 768])
label_fc.originalBias: torch.Size([2])
