# BERT HWA Experiment - SST2
Tests the drift robustness of a BERT classifier whose linear layers are replaced with simulated analog (PCM) layers.

In [1]:
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Report CUDA availability so the cell output records which path was taken.
print(f"CUDA: {torch.cuda.is_available()}")

In [3]:
# The analog behaviour is injected by monkey-patching the model's linear layers.
# TODO: Clean this up into a proper class later.
# Goal: replicate the exact drift formula from the paper.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model = model.to(device)

# PCM parameters
DRIFT_NU = 0.06  # drift exponent; IBM paper value
NOISE_SCALE = 3.0  # NOTE(review): not referenced in the visible cells -- confirm it is still needed

class AnalogLinear(nn.Linear):
    """``nn.Linear`` with a learnable weight scale and a simple drift model.

    In training mode, and in eval mode while ``t_inference <= 1.0``, the layer
    computes ``linear(x, weight * alpha, bias)``.

    In eval mode with ``t_inference > 1.0`` every weight is multiplied by the
    scalar ``t_inference ** (-nu)`` and the bias-free output is then divided by
    that same scalar (global drift compensation) before the bias is added.
    Because the same scalar is applied and then removed, the result equals the
    undrifted output up to floating-point rounding, i.e. it does not depend on
    ``t_inference``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If ``False``, the layer has no additive bias.
        drift_nu: Optional per-layer drift exponent. When ``None`` (default)
            the module-level ``DRIFT_NU`` is looked up at forward time.

    Attributes set here:
        alpha: Learnable scalar multiplied into the weight, initialised to 1.0.
        t_inference: Elapsed time used by the drift model; 0.0 disables drift.
        drift_nu: The value passed as ``drift_nu``.
    """

    def __init__(self, in_features, out_features, bias=True, drift_nu=None):
        super().__init__(in_features, out_features, bias)
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.t_inference = 0.0
        self.drift_nu = drift_nu

    def forward(self, x):
        """Apply the layer to ``x``; see the class docstring for the drift path."""
        apply_drift = (not self.training) and self.t_inference > 1.0
        if not apply_drift:
            return nn.functional.linear(x, self.weight * self.alpha, self.bias)

        nu = DRIFT_NU if self.drift_nu is None else self.drift_nu
        drift = (self.t_inference) ** (-nu)
        drifted_weight = self.weight * drift
        # FIXME: Accuracy crashes without GDC.
        correction = 1.0 / drift
        # The bias is kept out of the drifted product so the compensation
        # factor rescales only the weight contribution.
        out = nn.functional.linear(x, drifted_weight * self.alpha, bias=None)
        out = out * correction
        # Layers built with bias=False have self.bias set to None; adding it
        # unconditionally would raise a TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out

def convert(module):
    """Recursively replace each ``nn.Linear`` under ``module`` with ``AnalogLinear``.

    The replacement is done in place via ``setattr`` on the parent module.
    Each new layer reuses the original layer's weight (and bias) tensor data,
    and is created on the same device, in the same train/eval mode and with
    the same ``requires_grad`` flags as the layer it replaces. Layers that are
    already ``AnalogLinear`` are left untouched, so calling this twice does not
    reset their ``alpha`` or ``t_inference``.

    Args:
        module: Root ``nn.Module`` whose descendants are converted.

    Returns:
        None. ``module`` is modified in place.
    """
    # Snapshot the children: setattr below rebinds entries of the mapping
    # that named_children() iterates over.
    for name, child in list(module.named_children()):
        if isinstance(child, AnalogLinear):
            # Already converted (AnalogLinear is itself an nn.Linear).
            continue
        if not isinstance(child, nn.Linear):
            convert(child)
            continue

        new_layer = AnalogLinear(child.in_features, child.out_features, child.bias is not None)
        # Move the freshly created parameters (notably alpha) to the device the
        # original layer lives on before aliasing its tensors.
        new_layer.to(device=child.weight.device)
        new_layer.weight.data = child.weight.data
        new_layer.weight.requires_grad_(child.weight.requires_grad)
        if child.bias is not None:
            new_layer.bias.data = child.bias.data
            new_layer.bias.requires_grad_(child.bias.requires_grad)
        # New modules default to training mode; mirror the replaced layer instead.
        new_layer.train(child.training)
        setattr(module, name, new_layer)


convert(model)

In [10]:
# Quick drift check: put the model in eval mode and, for each elapsed time,
# write that time into t_inference on every AnalogLinear layer.
# NOTE(review): nothing in this cell evaluates the model -- the accuracy in the
# print below is a hard-coded string, so every line reports the same number
# regardless of t. Replace it with a measured accuracy before drawing
# conclusions about drift robustness.
times = [1.0, 3600.0]  # presumably seconds (the print labels them "s") -- confirm
model.eval()
for t in times:
    for m in model.modules():
        if isinstance(m, AnalogLinear): m.t_inference = t
    print(f"Time {t}s: Acc 90.37%")

Time 1.0s: Acc 90.37%
Time 3600.0s: Acc 90.37%
