In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers import AdamW
%matplotlib inline

In [2]:
class BinarizedLinearFunction(torch.autograd.Function):
    """Bias-free linear op whose forward pass uses hard-thresholded weights.

    Forward computes ``input.mm(B.t())`` where ``B`` is ``weight > threshold``
    cast to the input's dtype (entries are 0 or 1). Backward returns the
    gradient with respect to ``input`` through that same mask ``B``, and a
    straight-through estimate for ``weight`` (the thresholding step is treated
    as the identity, since its true derivative is zero almost everywhere).

    The threshold is a class-level setting shared by every call; change it
    with ``setThreshold``.
    """

    # Default cut-off so forward() works even if setThreshold() was never called.
    threshold = 0.5

    @staticmethod
    def forward(ctx, input, weight):
        """Multiply ``input`` by the binarized, transposed ``weight``.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            input: 2-D tensor (``mm`` requires matrices), shape (N, in).
            weight: 2-D real-valued tensor, shape (out, in).

        Returns:
            Tensor of shape (N, out) equal to ``input @ (weight > threshold).T``.
        """
        threshold = BinarizedLinearFunction.threshold

        ctx.save_for_backward(input, weight)
        # Remember the cut-off used for this call so backward() rebuilds the
        # exact same mask even if setThreshold() is invoked in between.
        ctx.threshold = threshold

        # Cast to the input's dtype instead of hard-coding float32 so that
        # float64 / half inputs do not fail with a dtype mismatch in mm().
        binary_weight = (weight > threshold).to(input.dtype)

        return input.mm(binary_weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_weight)`` for the saved forward call.

        ``grad_input`` is propagated through the binary mask that forward
        actually multiplied by. ``grad_weight`` is the straight-through
        estimate ``grad_output.T @ input``. Either entry is ``None`` when the
        corresponding input does not require a gradient.
        """
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        if ctx.needs_input_grad[0]:
            # Forward used the mask, not the raw weights, so the gradient
            # flowing back to the input must use the mask as well.
            binary_weight = (weight > ctx.threshold).to(grad_output.dtype)
            grad_input = grad_output.mm(binary_weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)

        return grad_input, grad_weight

    @staticmethod
    def setThreshold(threshold):
        """Set the class-wide cut-off used to binarize weights in later calls."""
        BinarizedLinearFunction.threshold = threshold
    
# Fix the class-wide binarization cut-off at 0.5 before the op is used.
BinarizedLinearFunction.setThreshold(0.5)
# Functional alias: BinLinear(input, weight) dispatches to the autograd op.
BinLinear = BinarizedLinearFunction.apply

In [None]:
# Smoke test: the op needs both an input matrix and a weight matrix; calling
# it with no arguments raises a TypeError. With an all-ones input, each output
# entry is the number of weights in that output row above the threshold
# (tensor([[0., 2., 4.], [0., 2., 4.]]) for a threshold of 0.5).
demo_input = torch.ones(2, 4)
demo_weight = torch.linspace(0.0, 1.0, steps=12).reshape(3, 4)
BinLinear(demo_input, demo_weight)