In [18]:
import math
import torch
import torch.nn as nn

torch.set_num_threads(1)

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification

bert_model = BertForSequenceClassification.from_pretrained('gchhablani/bert-base-cased-finetuned-qnli')
bert_model.eval()

# Access the model's weights
weights = bert_model.state_dict()

# Modify the weights or perform any operation you desire
# Example: Print the shape of each weight tensor
o = 0
b = 0
for name, weight in weights.items():
    if "weight" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        o += p
    elif "bias" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        b += p
    else:
        print(f"else {name}: {weight.size()}")
print(f"n_weight={o}, n_bias={b}, n_param={o+b}")


bert.embeddings.word_embeddings.weight: torch.Size([28996, 768])
bert.embeddings.position_embeddings.weight: torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight: torch.Size([2, 768])
bert.embeddings.LayerNorm.weight: torch.Size([768])
bert.embeddings.LayerNorm.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias: torch.Size([768

In [3]:
tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = bert_model(**inputs)
realput = outputs
print(realput)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.5781,  0.1473]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [4]:
import codecs

def load_tsv(data_file, max_seq_len, delimiter='\t'):
    '''Load a tsv '''
    sentences = []
    targets = []
    with codecs.open(data_file, 'r', 'utf-8') as data_fh:
        for _ in range(1):
            data_fh.readline()
        for row in data_fh:
            row = row.strip().split(delimiter)
            sentences.append(tokenizer(row[1], row[2], return_tensors="pt"))
            targets.append(1*(row[3] == "not_entailment"))
    return sentences, targets

data, targets = load_tsv("GLUE-baselines/glue_data/QNLI/dev.tsv", 1024)

In [5]:
count = 0
total = 0
for label in range(100):
    outputs = bert_model(**data[label])
    count += targets[label] == outputs.logits.argmax()
    total += 1
count / total

tensor(0.9100)

In [None]:

# class Attention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(Attention, self).__init__()

#         assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.search_dim = embed_dim // num_heads

#         self.key = nn.Linear(embed_dim, embed_dim)
#         self.value = nn.Linear(embed_dim, embed_dim)
#         self.query = nn.Linear(embed_dim, embed_dim)
#         self.proj = nn.Linear(embed_dim, embed_dim)


#     def forward(self, x):
#         batch_size = x.shape[0]
#         seq_len = x.shape[1]

#         k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
#         v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
#         q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

#         attn = q.matmul(k_t) / math.sqrt(q.size(-1))
#         attn = attn.softmax(dim=-1)
#         y = attn.matmul(v)
#         y = y.transpose(1, 2)
#         y = y.reshape(batch_size, seq_len, self.embed_dim)
#         return y

# class BertBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(BertBlock, self).__init__()
#         embed_dim = embed_dim
#         self.ln1 = nn.LayerNorm(embed_dim)
#         self.ln2 = nn.LayerNorm(embed_dim)
#         self.attention = Attention(embed_dim, num_heads)
#         self.ff = nn.Sequential(
#             nn.Linear(embed_dim, embed_dim * 4),
#             nn.GELU(),
#             nn.Linear(embed_dim * 4, embed_dim),
#         )

#     def forward(self, x):
#         x = self.ln1(x + self.attn(x))
#         x = self.ln2(x + self.ff(x))
#         return x

# class Bert(nn.Module):
#     def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len):
#         super(Bert, self).__init__()
#         self.embeddings = BertEmbeddings(vocab_size, embed_dim, seq_len)
#         self.encoder = nn.ModuleList(
#             [BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
#         )
#         self.pooler = nn.Linear(embed_dim, embed_dim)
#         self.classifier = nn.Linear(embed_dim, 2)

#     def forward(self, x):
#         if self.full:
#             word_embedding = self.word_embed(x["input_ids"])
#             pos_embedding = self.pos_embed[:, :x["input_ids"].size()[1], :]
#             type_embedding = self.type_embed(x["token_type_ids"])
#             x = word_embedding + pos_embedding + type_embedding
#             x = self.ln(x)
#         x = self.blocks(x)
#         if self.full:
#             x = self.pooler(x[:, 0, :])
#             x = x.tanh()
#             x = self.classifier(x)
#         return x

# class LLM:
#     def __init__(self, model):
#         self.bert = model

#     def forward(self, x):
#         return self.bert(x)


In [55]:
class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size, emb_size, max_seq_length):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, emb_size)
        self.position_embeddings = nn.Embedding(max_seq_length, emb_size)
        self.token_type_embeddings = nn.Embedding(2, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, input_ids, token_type_ids):
        word_emb = self.word_embeddings(input_ids)
        pos_emb = self.position_embeddings.weight[:input_ids.size()[1], :]
        type_emb = self.token_type_embeddings(token_type_ids)

        emb = word_emb + pos_emb + type_emb
        emb = self.LayerNorm(emb)

        return emb

class BertSelfAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_size = emb_size // self.n_heads
        self.query = nn.Linear(emb_size, emb_size)
        self.key = nn.Linear(emb_size, emb_size)
        self.value = nn.Linear(emb_size, emb_size)

    def forward(self, emb):
        B, T, C = emb.shape  # batch size, sequence length, embedding size

        q = self.query(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = self.key(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = self.value(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        weights = q @ k.transpose(-2, -1) * self.head_size**-0.5

        weights = weights.softmax(dim=-1)

        emb_rich = weights @ v
        emb_rich = emb_rich.transpose(1, 2).contiguous().view(B, T, C)

        return emb_rich

class BertSelfOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, emb_rich, emb):
        x = self.dense(emb_rich)
        x = x + emb
        out = self.LayerNorm(x)

        return out


class BertAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.self = BertSelfAttention(emb_size, n_heads)
        self.output = BertSelfOutput(emb_size)

    def forward(self, emb):
        emb_rich = self.self(emb)
        out = self.output(emb_rich, emb)

        return out

class BertIntermediate(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, 4 * emb_size)
        self.gelu = nn.GELU()

    def forward(self, att_out):
        x = self.dense(att_out)
        out = self.gelu(x)
        return out


class BertOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(4 * emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, intermediate_out, att_out):
        x = self.dense(intermediate_out)
        x = x + att_out
        out = self.LayerNorm(x)

        return out

class BertLayer(nn.Module):
    def __init__(self, emb_size, n_heads ):
        super().__init__()
        self.attention = BertAttention(emb_size, n_heads)
        self.intermediate = BertIntermediate(emb_size)
        self.output = BertOutput(emb_size)

    def forward(self, emb):
        att_out = self.attention(emb)
        intermediate_out = self.intermediate(att_out)
        out = self.output(intermediate_out, att_out)

        return out

class BertEncoder(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(emb_size, n_heads) for i in range(12)])

    def forward(self, x):
        for l in self.layer:
            x = l(x)
        return x

class BertPooler(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.tanh = nn.Tanh()

    def forward(self, encoder_out):
        pool_first_token = encoder_out[:, 0]
        out = self.dense(pool_first_token)
        out = self.tanh(out)
        return out

class BertModel(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads):
        super().__init__()
        self.embeddings = BertEmbeddings(vocab_size, emb_size, seq_len)
        self.encoder = BertEncoder(emb_size, n_heads)
        self.pooler = BertPooler(emb_size)

    def forward(self, input_ids, token_type_ids):
        emb = self.embeddings(input_ids, token_type_ids)
        out = self.encoder(emb)
        pooled_out = self.pooler(out)
        return out, pooled_out

class BertForSequenceClassification(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads):
        super().__init__()
        self.bert = BertModel(vocab_size, emb_size, seq_len, n_heads)
        self.classifier = nn.Linear(emb_size, 2)

    def forward(self, input_ids, token_type_ids):
        _, pooled_out = self.bert(input_ids, token_type_ids)
        logits = self.classifier(pooled_out)
        return logits

# plain_model = LLM(Bert(768, 12, 12, 30522, 1024))
plain_model = BertForSequenceClassification(28996, 768, 512, 12)
plain_model.load_state_dict(weights)

<All keys matched successfully>

In [17]:
plain_model.bert.embeddings.word_embeddings.weight = nn.Parameter(weights["bert.embeddings.word_embeddings.weight"])
plain_model.bert.embeddings.position_embeddings.weight = nn.Parameter(weights["bert.embeddings.position_embeddings.weight"][None, :, :])
plain_model.bert.embeddings.token_type_embeddings.weight = nn.Parameter(weights["bert.embeddings.token_type_embeddings.weight"])
plain_model.bert.embeddings.LayerNorm.weight = nn.Parameter(weights["bert.embeddings.LayerNorm.weight"])
plain_model.bert.embeddings.LayerNorm.bias = nn.Parameter(weights["bert.embeddings.LayerNorm.bias"])

In [7]:
# torch.load_state_dict
for m in range(len(plain_model.blocks._modules)):
    layer = "bert.encoder.layer."
    plain_model.blocks._modules[str(m)].attn.query.weight = nn.Parameter(weights[layer+str(m)+".attention.self.query.weight"])
    plain_model.blocks._modules[str(m)].attn.query.bias = nn.Parameter(weights[layer+str(m)+".attention.self.query.bias"])
    plain_model.blocks._modules[str(m)].attn.key.weight = nn.Parameter(weights[layer+str(m)+".attention.self.key.weight"])
    plain_model.blocks._modules[str(m)].attn.key.bias = nn.Parameter(weights[layer+str(m)+".attention.self.key.bias"])
    plain_model.blocks._modules[str(m)].attn.value.weight = nn.Parameter(weights[layer+str(m)+".attention.self.value.weight"])
    plain_model.blocks._modules[str(m)].attn.value.bias = nn.Parameter(weights[layer+str(m)+".attention.self.value.bias"])
    plain_model.blocks._modules[str(m)].attn.proj.weight = nn.Parameter(weights[layer+str(m)+".attention.output.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].attn.proj.bias = nn.Parameter(weights[layer+str(m)+".attention.output.dense.bias"])
    plain_model.blocks._modules[str(m)].ln1.weight = nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.weight"])
    plain_model.blocks._modules[str(m)].ln1.bias = nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.bias"])
    plain_model.blocks._modules[str(m)].ff._modules['0'].weight = nn.Parameter(weights[layer+str(m)+".intermediate.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].ff._modules['0'].bias = nn.Parameter(weights[layer+str(m)+".intermediate.dense.bias"])
    plain_model.blocks._modules[str(m)].ff._modules['2'].weight = nn.Parameter(weights[layer+str(m)+".output.dense.weight"]) # .t()
    plain_model.blocks._modules[str(m)].ff._modules['2'].bias = nn.Parameter(weights[layer+str(m)+".output.dense.bias"])
    plain_model.blocks._modules[str(m)].ln2.weight = nn.Parameter(weights[layer+str(m)+".output.LayerNorm.weight"])
    plain_model.blocks._modules[str(m)].ln2.bias = nn.Parameter(weights[layer+str(m)+".output.LayerNorm.bias"])
plain_model.pooler.weight = nn.Parameter(weights["bert.pooler.dense.weight"]) # .t()
plain_model.pooler.bias = nn.Parameter(weights["bert.pooler.dense.bias"])
plain_model.classifier.weight = nn.Parameter(weights["classifier.weight"]) # .t()
plain_model.classifier.bias = nn.Parameter(weights["classifier.bias"])

In [8]:
# Load data to Bob
print('loading data')
# Classify the encrypted data
print("forward")
x = {}
x['input_ids'] = data[0]["input_ids"]
x['token_type_ids'] = data[0]["token_type_ids"]
# x['input_ids'] = torch.tensor([1])
# x['token_type_ids'] = torch.tensor([0, 1, 0, 1])
word_embedding = plain_model.word_embed(x["input_ids"])
pos_embedding = plain_model.pos_embed[:, :x["input_ids"].size()[1], :]
type_embedding = plain_model.type_embed(x["token_type_ids"])
x = word_embedding + pos_embedding + type_embedding
x = plain_model.ln(x)
# v = '0'
# x = plain_model.blocks._modules[v].attn(x)
x = plain_model.blocks._modules['0'](x)
x = plain_model.blocks._modules['1'](x)
x = plain_model.blocks._modules['2'](x)
x = plain_model.blocks._modules['3'](x)
x = plain_model.blocks._modules['4'](x)
x = plain_model.blocks._modules['5'](x)
x = plain_model.blocks._modules['6'](x)
x = plain_model.blocks._modules['7'](x)
x = plain_model.blocks._modules['8'](x)
x = plain_model.blocks._modules['9'](x)
x = plain_model.blocks._modules['10'](x)
x = plain_model.blocks._modules['11'](x)
x = plain_model.pooler(x[:, 0, :])
x = x.tanh()
x = plain_model.classifier(x)

# x = plain_model(x)
print('output')
# Compute the accuracy
output = x
print(f"{output.shape=}")
print(f"{output=}")
print(f"{output.max()} {output.min()}")
print(f"{output.mean()} {output.var()}")

loading data
forward
output
output.shape=torch.Size([1, 2])
output=tensor([[-0.7974,  0.5348]], grad_fn=<AddmmBackward0>)
0.5348405241966248 -0.7973539233207703
-0.13125669956207275 0.8873710036277771


In [58]:
count = 0
total = 0
for label in range(100):
    x = {}
    x['input_ids'] = data[label]["input_ids"]
    x['token_type_ids'] = data[label]["token_type_ids"]
    outputs = plain_model(x['input_ids'], x['token_type_ids'])
    print(outputs)
    count += targets[label] == outputs.argmax()
    total += 1
count / total

tensor([[ 1.9990, -1.6248]], grad_fn=<AddmmBackward0>)
tensor([[-3.5682,  2.9418]], grad_fn=<AddmmBackward0>)
tensor([[-2.8683,  2.4348]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1733, -2.6358]], grad_fn=<AddmmBackward0>)
tensor([[-2.2052,  1.8399]], grad_fn=<AddmmBackward0>)
tensor([[-3.6817,  3.0015]], grad_fn=<AddmmBackward0>)
tensor([[-1.3191,  1.2331]], grad_fn=<AddmmBackward0>)
tensor([[-3.7189,  2.9994]], grad_fn=<AddmmBackward0>)
tensor([[-3.8530,  3.0773]], grad_fn=<AddmmBackward0>)
tensor([[ 3.5190, -2.9471]], grad_fn=<AddmmBackward0>)
tensor([[-3.6168,  2.9147]], grad_fn=<AddmmBackward0>)
tensor([[ 3.6551, -2.9953]], grad_fn=<AddmmBackward0>)
tensor([[ 3.3681, -2.8038]], grad_fn=<AddmmBackward0>)
tensor([[-3.8698,  3.0741]], grad_fn=<AddmmBackward0>)
tensor([[-1.7479,  1.3016]], grad_fn=<AddmmBackward0>)
tensor([[ 3.7329, -3.1091]], grad_fn=<AddmmBackward0>)
tensor([[ 3.3685, -2.8042]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1636, -2.6641]], grad_fn=<AddmmBackward0>)
tensor([[ 

tensor(0.9100)

In [59]:
count = 0
total = 0
for i in range(100):
    outputs = bert_model(**data[i])
    print(outputs.logits)
    count += targets[i] == outputs.logits.argmax()
    total += 1
count / total

tensor([[ 2.0014, -1.6264]], grad_fn=<AddmmBackward0>)
tensor([[-3.5682,  2.9418]], grad_fn=<AddmmBackward0>)
tensor([[-2.8672,  2.4340]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1729, -2.6355]], grad_fn=<AddmmBackward0>)
tensor([[-2.2055,  1.8402]], grad_fn=<AddmmBackward0>)
tensor([[-3.6816,  3.0014]], grad_fn=<AddmmBackward0>)
tensor([[-1.3215,  1.2353]], grad_fn=<AddmmBackward0>)
tensor([[-3.7192,  2.9996]], grad_fn=<AddmmBackward0>)
tensor([[-3.8530,  3.0773]], grad_fn=<AddmmBackward0>)
tensor([[ 3.5189, -2.9471]], grad_fn=<AddmmBackward0>)
tensor([[-3.6169,  2.9147]], grad_fn=<AddmmBackward0>)
tensor([[ 3.6550, -2.9953]], grad_fn=<AddmmBackward0>)
tensor([[ 3.3679, -2.8037]], grad_fn=<AddmmBackward0>)
tensor([[-3.8698,  3.0740]], grad_fn=<AddmmBackward0>)
tensor([[-1.7491,  1.3029]], grad_fn=<AddmmBackward0>)
tensor([[ 3.7329, -3.1090]], grad_fn=<AddmmBackward0>)
tensor([[ 3.3681, -2.8039]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1635, -2.6641]], grad_fn=<AddmmBackward0>)
tensor([[ 

tensor(0.9100)

In [54]:
print(bert_model)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [11]:
data[0]

{'input_ids': tensor([[  101,  1327,  1338,  1154,  2049,  1170,  1103,  1207,  7119,  1108,
          1123, 18728,   136,   102,  1249,  1104,  1115,  1285,   117,  1103,
          1207,  7119,  1123, 18728,  1158,  1103,  2307,  2250,  1338,  1154,
          2049,   119,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1]])}