In [1]:
import math
import torch

torch.set_num_threads(1)

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification

bert_model = BertForSequenceClassification.from_pretrained('gchhablani/bert-base-cased-finetuned-qnli')
bert_model.eval()

# Access the model's weights
weights = bert_model.state_dict()

# Modify the weights or perform any operation you desire
# Example: Print the shape of each weight tensor
o = 0
b = 0
for name, weight in weights.items():
    if "weight" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        o += p
    elif "bias" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        b += p
    else:
        print(f"else {name}: {weight.size()}")
print(f"n_weight={o}, n_bias={b}, n_param={o+b}")


bert.embeddings.word_embeddings.weight: torch.Size([28996, 768])
bert.embeddings.position_embeddings.weight: torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight: torch.Size([2, 768])
bert.embeddings.LayerNorm.weight: torch.Size([768])
bert.embeddings.LayerNorm.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias: torch.Size([768

In [3]:
tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = bert_model(**inputs)
realput = outputs
print(realput)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.5781,  0.1473]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [4]:
import codecs

def load_tsv(data_file, max_seq_len, delimiter='\t'):
    '''Load a tsv '''
    sentences = []
    targets = []
    with codecs.open(data_file, 'r', 'utf-8') as data_fh:
        for _ in range(1):
            data_fh.readline()
        for row in data_fh:
            row = row.strip().split(delimiter)
            sentences.append(tokenizer(row[1], row[2], return_tensors="pt"))
            targets.append(1*(row[3] == "not_entailment"))
    return sentences, targets

data, targets = load_tsv("GLUE-baselines/glue_data/QNLI/dev.tsv", 1024)

In [5]:
count = 0
total = 0
for label in range(100):
    outputs = bert_model(**data[label])
    count += targets[label] == outputs.logits.argmax()
    total += 1
count / total

tensor(0.9100)

In [6]:
class Attention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Attention, self).__init__()

        assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.search_dim = embed_dim // num_heads

        self.key = torch.nn.Linear(embed_dim, embed_dim)
        self.value = torch.nn.Linear(embed_dim, embed_dim)
        self.query = torch.nn.Linear(embed_dim, embed_dim)
        self.proj = torch.nn.Linear(embed_dim, embed_dim)


    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

        attn = q.matmul(k_t) / math.sqrt(q.size(-1))
        attn = attn.softmax(dim=-1)
        y = attn.matmul(v)
        y = y.transpose(1, 2)
        y = y.reshape(batch_size, seq_len, self.embed_dim)
        return y


class BertBlock(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BertBlock, self).__init__()
        embed_dim = embed_dim
        self.ln1 = torch.nn.LayerNorm(embed_dim)
        self.ln2 = torch.nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, num_heads)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, embed_dim * 4),
            torch.nn.GELU(),
            torch.nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = self.ln1(x + self.attn(x))
        x = self.ln2(x + self.ff(x))
        return x

class Bert(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(Bert, self).__init__()
        self.full = full
        if full:
            self.word_embed = torch.nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = torch.nn.Parameter(torch.zeros(1, seq_len, embed_dim))
            self.type_embed = torch.nn.Embedding(2, embed_dim)
            self.ln = torch.nn.LayerNorm(embed_dim)
        self.blocks = torch.nn.Sequential(
            *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        if full:
            self.pooler = torch.nn.Linear(embed_dim, embed_dim)
            self.classifier = torch.nn.Linear(embed_dim, 2)

    def forward(self, x, target=None):
        if self.full:
            word_embedding = self.word_embed(x["input_ids"])
            pos_embedding = self.pos_embed[:, :x["input_ids"].size()[1], :]
            type_embedding = self.type_embed(x["token_type_ids"])
            x = word_embedding + pos_embedding + type_embedding
            x = self.ln(x)
        x = self.blocks(x)
        if self.full:
            x = self.pooler(x[:, 0, :])
            x = x.tanh()
            x = self.classifier(x)
        return x

plain_model = Bert(768, 12, 12, 30522, 1024, True) # bert base 13.5s

In [7]:
plain_model.word_embed.weight = torch.nn.Parameter(weights["bert.embeddings.word_embeddings.weight"])
plain_model.pos_embed.data = torch.nn.Parameter(weights["bert.embeddings.position_embeddings.weight"][None, :, :])
plain_model.type_embed.weight = torch.nn.Parameter(weights["bert.embeddings.token_type_embeddings.weight"])
plain_model.ln.weight = torch.nn.Parameter(weights["bert.embeddings.LayerNorm.weight"])
plain_model.ln.bias = torch.nn.Parameter(weights["bert.embeddings.LayerNorm.bias"])
for m in range(len(plain_model.blocks._modules)):
    layer = "bert.encoder.layer."
    plain_model.blocks._modules[str(m)].attn.query.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.query.weight"])
    plain_model.blocks._modules[str(m)].attn.query.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.query.bias"])
    plain_model.blocks._modules[str(m)].attn.key.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.key.weight"])
    plain_model.blocks._modules[str(m)].attn.key.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.key.bias"])
    plain_model.blocks._modules[str(m)].attn.value.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.value.weight"])
    plain_model.blocks._modules[str(m)].attn.value.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.value.bias"])
    plain_model.blocks._modules[str(m)].attn.proj.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.output.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].attn.proj.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.output.dense.bias"])
    plain_model.blocks._modules[str(m)].ln1.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.weight"])
    plain_model.blocks._modules[str(m)].ln1.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.bias"])
    plain_model.blocks._modules[str(m)].ff._modules['0'].weight = torch.nn.Parameter(weights[layer+str(m)+".intermediate.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].ff._modules['0'].bias = torch.nn.Parameter(weights[layer+str(m)+".intermediate.dense.bias"])
    plain_model.blocks._modules[str(m)].ff._modules['2'].weight = torch.nn.Parameter(weights[layer+str(m)+".output.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].ff._modules['2'].bias = torch.nn.Parameter(weights[layer+str(m)+".output.dense.bias"])
    plain_model.blocks._modules[str(m)].ln2.weight = torch.nn.Parameter(weights[layer+str(m)+".output.LayerNorm.weight"])
    plain_model.blocks._modules[str(m)].ln2.bias = torch.nn.Parameter(weights[layer+str(m)+".output.LayerNorm.bias"])
plain_model.pooler.weight = torch.nn.Parameter(weights["bert.pooler.dense.weight"]) # .t()
plain_model.pooler.bias = torch.nn.Parameter(weights["bert.pooler.dense.bias"])
plain_model.classifier.weight = torch.nn.Parameter(weights["classifier.weight"]) # .t()
plain_model.classifier.bias = torch.nn.Parameter(weights["classifier.bias"])

In [8]:
# Load data to Bob
print('loading data')
# Classify the encrypted data
print("forward")
x = {}
x['input_ids'] = data[0]["input_ids"]
x['token_type_ids'] = data[0]["token_type_ids"]
# x['input_ids'] = torch.tensor([1])
# x['token_type_ids'] = torch.tensor([0, 1, 0, 1])
word_embedding = plain_model.word_embed(x["input_ids"])
pos_embedding = plain_model.pos_embed[:, :x["input_ids"].size()[1], :]
type_embedding = plain_model.type_embed(x["token_type_ids"])
x = word_embedding + pos_embedding + type_embedding
x = plain_model.ln(x)
# v = '0'
# x = plain_model.blocks._modules[v].attn(x)
x = plain_model.blocks._modules['0'](x)
x = plain_model.blocks._modules['1'](x)
x = plain_model.blocks._modules['2'](x)
x = plain_model.blocks._modules['3'](x)
x = plain_model.blocks._modules['4'](x)
x = plain_model.blocks._modules['5'](x)
x = plain_model.blocks._modules['6'](x)
x = plain_model.blocks._modules['7'](x)
x = plain_model.blocks._modules['8'](x)
x = plain_model.blocks._modules['9'](x)
x = plain_model.blocks._modules['10'](x)
x = plain_model.blocks._modules['11'](x)
x = plain_model.pooler(x[:, 0, :])
x = x.tanh()
x = plain_model.classifier(x)

# x = plain_model(x)
print('output')
# Compute the accuracy
output = x
print(f"{output.shape=}")
print(f"{output=}")
print(f"{output.max()} {output.min()}")
print(f"{output.mean()} {output.var()}")

loading data
forward
output
output.shape=torch.Size([1, 2])
output=tensor([[-0.7974,  0.5348]], grad_fn=<AddmmBackward0>)
0.5348405241966248 -0.7973539233207703
-0.13125669956207275 0.8873710036277771


In [9]:
count = 0
total = 0
for label in range(10):
    x = {}
    x['input_ids'] = data[label]["input_ids"]
    x['token_type_ids'] = data[label]["token_type_ids"]
    outputs = plain_model(x)
    print(outputs)
    count += targets[label] == outputs.argmax()
    total += 1
count / total

tensor([[-0.7974,  0.5348]], grad_fn=<AddmmBackward0>)
tensor([[-0.6113,  0.3770]], grad_fn=<AddmmBackward0>)
tensor([[-0.6552,  0.3993]], grad_fn=<AddmmBackward0>)
tensor([[-0.2674,  0.0475]], grad_fn=<AddmmBackward0>)
tensor([[-0.0812,  0.0456]], grad_fn=<AddmmBackward0>)
tensor([[-0.1108,  0.0706]], grad_fn=<AddmmBackward0>)
tensor([[-0.5361,  0.2477]], grad_fn=<AddmmBackward0>)
tensor([[-1.0271,  0.7352]], grad_fn=<AddmmBackward0>)
tensor([[-0.0689, -0.0279]], grad_fn=<AddmmBackward0>)
tensor([[-0.5986,  0.4556]], grad_fn=<AddmmBackward0>)


tensor(0.7000)

In [10]:
count = 0
total = 0
for i in range(10):
    outputs = bert_model(**data[i])
    print(outputs.logits)
    count += targets[i] == outputs.logits.argmax()
    total += 1
count / total

tensor([[ 2.0014, -1.6264]], grad_fn=<AddmmBackward0>)
tensor([[-3.5682,  2.9418]], grad_fn=<AddmmBackward0>)
tensor([[-2.8672,  2.4340]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1729, -2.6355]], grad_fn=<AddmmBackward0>)
tensor([[-2.2055,  1.8402]], grad_fn=<AddmmBackward0>)
tensor([[-3.6816,  3.0014]], grad_fn=<AddmmBackward0>)
tensor([[-1.3215,  1.2353]], grad_fn=<AddmmBackward0>)
tensor([[-3.7192,  2.9996]], grad_fn=<AddmmBackward0>)
tensor([[-3.8530,  3.0773]], grad_fn=<AddmmBackward0>)
tensor([[ 3.5189, -2.9471]], grad_fn=<AddmmBackward0>)


tensor(1.)