In [7]:
# Select the machine this notebook runs on; it determines where downloads are cached.
storageName = 'university'

# Root working folder per environment (each entry ends with '/').
_STORAGE_ROOTS = {
    'paperspace': "/notebooks/",
    'colab': "/content/",
    'university': '/vol/scratch/guy/',
}
# Fail fast with a clear message instead of a NameError on `guy_folder` further down.
if storageName not in _STORAGE_ROOTS:
    raise ValueError(
        f"Unknown storageName {storageName!r}; expected one of {sorted(_STORAGE_ROOTS)}"
    )
guy_folder = _STORAGE_ROOTS[storageName]

# Strip the trailing slash before appending so the path contains no '//' segment.
cache_dir = guy_folder.rstrip('/') + "/cache/transformer_cache"

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers import AdamW
%matplotlib inline

In [63]:
def _introduceBinLinear(m_):
    children = m_.named_children()
    for name, m in children:
        if name in ['k_lin', 'q_lin', 'v_lin']:
            m_.__setattr__(name, BinLinear(m.weight, m.bias))




class Binarizer(torch.autograd.Function):
    """Hard-threshold a tensor to {0, 1} with a straight-through gradient.

    Forward: ``1.0`` where ``input > Binarizer.threshold``, else ``0.0``
    (returned as float32 via ``.float()``).
    Backward: the incoming gradient is returned unchanged (straight-through
    estimator), because the step function's true gradient is zero almost
    everywhere and would stop all learning.

    The threshold is one class-level value shared by every call; change it
    with ``Binarizer.setThreshold``.
    """

    # Class-wide default so forward() works even before setThreshold() is called.
    threshold = 0.5

    @staticmethod
    def forward(ctx, input):
        return (input > Binarizer.threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pretend the forward pass was the identity.
        return grad_output

    @staticmethod
    def setThreshold(threshold):
        """Set the threshold used by all subsequent ``Binarizer.apply`` calls."""
        Binarizer.threshold = threshold


Binarizer.setThreshold(0.5)



class BinLinear(nn.Module):
    """Linear layer whose existing weight is gated by a learnable binary mask.

    Stands in for an ``nn.Linear``. It reuses that layer's weight and bias
    tensors and adds ``rawMask``, one real-valued score per weight entry. In
    ``forward`` the scores are hard-thresholded by ``Binarizer`` (gradient
    passed straight through), and the 0/1 result multiplies the weight.

    Args:
        originalWeight: weight of the replaced layer, laid out
            ``(out_features, in_features)`` as ``F.linear`` expects.
        originalBias: bias of the replaced layer, or ``None``.
        initialSparsity: probability that a mask score starts at exactly 0.
            The other scores are uniform in [0, 1), so with a threshold of 0.5
            roughly half of those also start masked out.

    Note: when the passed tensors are ``nn.Parameter`` objects, assigning them
    registers them as parameters of this module as well. Whether they train
    therefore depends on their own ``requires_grad`` flag.
    """

    def __init__(self, originalWeight, originalBias, initialSparsity = .7):
        super().__init__()
        self.originalWeight = originalWeight
        self.originalBias = originalBias

        # Build the scores on the weight's device/dtype. Wrapping a model that
        # already lives on the GPU then does not leave the mask behind on CPU.
        like = dict(device=originalWeight.device, dtype=originalWeight.dtype)
        scores = torch.rand(originalWeight.shape, **like)
        keep = self._sparsityMask(initialSparsity, originalWeight.shape, **like)
        self.rawMask = nn.Parameter(scores * keep)

    def forward(self, x):
        # Binarizer returns float32; match the weight dtype so the product keeps it.
        mask = Binarizer.apply(self.rawMask).to(self.originalWeight.dtype)
        # F.linear computes x @ w.T + bias (bias may be None) in one fused call.
        return F.linear(x, self.originalWeight * mask, self.originalBias)

    def _sparsityMask(self, initialSparsity, shape, device=None, dtype=torch.float32):
        """Return a 0/1 tensor whose entries are 1 with probability ``1 - initialSparsity``."""
        return (torch.rand(shape, device=device) > initialSparsity).to(dtype)

    def extra_repr(self):
        # Printed as "out,in" (rawMask has the weight's shape), plus the bias flag.
        rep = '{},{}'.format(*self.rawMask.shape)
        if self.originalBias is not None:
            rep += ', bias = True'
        return rep


In [64]:
# Pretrained cased DistilBERT; transformers stores/reuses the downloaded files under `cache_dir`.
bert = AutoModel.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-cased",
    cache_dir=cache_dir,
)

In [65]:
# Freeze every pretrained parameter, then swap the k/q/v projections for BinLinear layers.
bert.requires_grad_(False)
bert.apply(_introduceBinLinear)
# Explicitly re-enable gradients on the mask scores only. Relying on new rawMask tensors
# defaulting to requires_grad=True breaks on re-run: the freeze above would also freeze
# masks created by a previous run of this cell.
for binLayer in (mod for mod in bert.modules() if isinstance(mod, BinLinear)):
    binLayer.rawMask.requires_grad_(True)