In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [None]:
# Load the ModernBERT base checkpoint together with its masked-language-modeling head.
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

In [None]:
# Example sentence with two [MASK] tokens for the model to fill in.
text = "The text test to [MASK] how to work with multiple [MASK] tokens."
inputs = tokenizer(text, return_tensors="pt")
# Inference only: no_grad avoids recording an autograd graph for this forward pass.
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# Tokenized input: token ids plus attention mask; [MASK] positions hold tokenizer.mask_token_id.
inputs

{'input_ids': tensor([[50281,   510,  2505,  1071,   281, 50284,   849,   281,   789,   342,
          2709, 50284, 21761,    15, 50282]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [None]:
# Id of the [MASK] token, used to locate mask positions in input_ids.
tokenizer.mask_token_id

50284

In [None]:
# Raw model output; `logits` holds one score per vocabulary entry at every input position.
outputs

MaskedLMOutput(loss=None, logits=tensor([[[  0.6721,  -1.8583,   6.8079,  ...,  -1.7623,  -1.7692,  -1.7357],
         [  2.2619,  -1.3756,   6.5374,  ...,  -1.3865,  -1.3911,  -1.3859],
         [-10.1584,  -4.7012,   1.5854,  ...,  -3.1871,  -3.1808,  -3.1888],
         ...,
         [ -4.5846,  -3.8667,   6.3854,  ...,  -2.8660,  -2.8651,  -2.8700],
         [ -4.3328,  -5.5026,  20.1760,  ...,  -6.0309,  -6.0219,  -6.0692],
         [  6.5261,  -3.7096,   6.2556,  ...,  -2.2066,  -2.2090,  -2.2049]]],
       grad_fn=<CompiledFunctionBackward>), hidden_states=None, attentions=None)

In [None]:
# Logits add a trailing vocabulary dimension to the (batch, seq_len) input shape.
inputs["input_ids"].shape, outputs.logits.shape

(torch.Size([1, 15]), torch.Size([1, 15, 50368]))

In [None]:
# Greedy decode: argmax is taken at EVERY position, so non-[MASK] tokens are re-predicted too,
# not just the masks.
tokenizer.decode(outputs.logits[0].argmax(axis=-1))

'[CLS]The text test to demonstrate how to work with multiple text tokens.[SEP]'

In [None]:
def predict(masked_text, model=model, tokenizer=tokenizer, skip_special_tokens=True):
    """Fill every mask token in ``masked_text`` with the model's top prediction.

    Only positions holding ``tokenizer.mask_token_id`` are replaced; every other
    token is kept exactly as tokenized, so the surrounding text is not rewritten.

    Args:
        masked_text: A single string containing one or more mask tokens.
        model: Masked-LM model whose output exposes ``.logits``
            (defaults to the notebook's ``model``).
        tokenizer: Tokenizer matching ``model`` (defaults to the notebook's ``tokenizer``).
        skip_special_tokens: Drop special tokens such as [CLS]/[SEP] when decoding.

    Returns:
        The decoded string with the masks filled in.
    """
    inputs = tokenizer(masked_text, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    input_ids = inputs["input_ids"][0]
    predicted_ids = logits[0].argmax(dim=-1)
    # An argmax over every position would also re-predict (and possibly change)
    # non-mask tokens; keep the original ids everywhere except mask positions.
    is_mask = input_ids == tokenizer.mask_token_id
    filled_ids = torch.where(is_mask, predicted_ids, input_ids)
    return tokenizer.decode(filled_ids, skip_special_tokens=skip_special_tokens)

In [None]:
# Fill the masks in the example sentence.
predict(text)

'The text test to demonstrate how to work with multiple text tokens.'

In [None]:
# Two inputs of different lengths: padding=True pads the shorter one, and its
# attention_mask is 0 at the padded positions.
batch_inputs = tokenizer(["[MASK]", "[MASK]! Wow, wow!"], return_tensors="pt", padding=True)
batch_inputs

{'input_ids': tensor([[50281, 50284, 50282, 50283, 50283, 50283, 50283, 50283],
        [50281, 50284,     2, 42340,    13, 39361,     2, 50282]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])}

In [None]:
# Inference only: no_grad avoids recording an autograd graph for this forward pass.
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
batch_outputs

MaskedLMOutput(loss=None, logits=tensor([[[  8.6314,  -1.7732,   6.0550,  ...,  -1.4146,  -1.4161,  -1.3754],
         [  6.6756,  -0.7836,   7.6701,  ...,  -0.5850,  -0.5960,  -0.5838],
         [  6.8949,  -3.4749,   6.5174,  ...,  -1.9860,  -1.9905,  -1.9836],
         ...,
         [  2.0362,  -2.7576,   6.5026,  ...,  -2.1714,  -2.1641,  -2.1698],
         [  1.3960,  -2.7355,   6.7043,  ...,  -2.2261,  -2.2195,  -2.2252],
         [ -2.0583,  -2.9324,   8.3906,  ...,  -2.5469,  -2.5530,  -2.5491]],

        [[  2.8769,  -2.1932,  12.4853,  ...,  -2.0466,  -2.0384,  -2.0045],
         [  0.7076,  -0.5435,   7.3688,  ...,  -1.0668,  -1.0865,  -1.0707],
         [ -2.4343,  -3.9094,  23.7442,  ...,  -4.2976,  -4.2955,  -4.2777],
         ...,
         [ -8.7519,  -6.0721,   7.7163,  ...,  -4.6666,  -4.6627,  -4.6733],
         [-10.1921,  -3.5585,  51.1025,  ...,  -4.5226,  -4.5250,  -4.5277],
         [  6.9014,  -3.5831,   7.2960,  ...,  -2.1184,  -2.1223,  -2.1159]]],
       grad

In [None]:
# Batched logits keep the padded (batch, seq_len) shape plus a vocabulary dimension.
batch_inputs["input_ids"].shape, batch_outputs.logits.shape

(torch.Size([2, 8]), torch.Size([2, 8, 50368]))

In [None]:
# One predicted token id per (batch, position).
batch_outputs.logits.argmax(axis=-1).shape

torch.Size([2, 8])

In [None]:
# Padded positions also receive predictions here, so they decode as extra tokens
# after the real text of the shorter input.
tokenizer.batch_decode(batch_outputs.logits.argmax(axis=-1))

['[CLS]Affirmed[SEP][SEP][SEP][SEP][SEP]\n', '[CLS]Wow! Wow, wow![SEP]']

In [None]:
def batch_predict(masked_text, model=model, tokenizer=tokenizer, skip_special_tokens=True):
    """Fill mask tokens in one string or a list of strings.

    Only positions holding ``tokenizer.mask_token_id`` are replaced with the
    model's top prediction; all other tokens, including padding, are kept as
    tokenized.

    Args:
        masked_text: A string or a list of strings containing mask tokens.
        model: Masked-LM model whose output exposes ``.logits``
            (defaults to the notebook's ``model``).
        tokenizer: Tokenizer matching ``model`` (defaults to the notebook's ``tokenizer``).
        skip_special_tokens: Drop special tokens ([CLS], [SEP], padding) when decoding.

    Returns:
        A list with one decoded string per input (length 1 for a single string).
    """
    inputs = tokenizer(masked_text, return_tensors="pt", padding=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    input_ids = inputs["input_ids"]
    predicted_ids = logits.argmax(dim=-1)
    # Replace mask positions only. Keeping the original ids elsewhere leaves the
    # surrounding text intact. It also leaves padding as the pad token, which
    # skip_special_tokens then drops instead of decoding stray predictions.
    is_mask = input_ids == tokenizer.mask_token_id
    filled_ids = torch.where(is_mask, predicted_ids, input_ids)
    return tokenizer.batch_decode(filled_ids, skip_special_tokens=skip_special_tokens)

In [None]:
# Batched prediction over inputs of different lengths.
batch_predict(["[MASK]", "[MASK]! Wow, wow!"])

['Affirmed\n', 'Wow! Wow, wow!']

In [None]:
# A single string is accepted as well; the result is still a list.
batch_predict("wow, [MASK]!")

['wow, wow!']

In [None]:
# From here on `predict` refers to `batch_predict`: it returns a *list* of decoded
# strings, which is why later cells index the result with `[0]`.
predict=batch_predict

In [None]:
# Sample code snippet and review comment used to probe the model's judgement.
# NOTE(review): the `else` branch evaluates "failure" without `return`; confirm
# whether that is intentional test data or a typo.
code = """
if profit > 0:
    return "success"
else:
    "failure"
"""

comment = "NOTE: even if the profit is not negative, it's a failure. It MUST be positive."

In [None]:
# Fill-in-the-blank prompt: the model's prediction for the trailing [MASK] is read
# as the yes/no answer about whether the comment fits the code.
inp = f"""
Given the following code:
'''
{code}
'''

and the following comment:
'''
{comment}
'''

the answer to a simple yes/no question if the comment above is relevant and correct to the code above is [MASK]
"""
print(inp)


Given the following code:
'''

if profit > 0:
    return "success"
else:
    "failure"

'''

and the following comment:
'''
NOTE: even if the profit is not negative, it's a failure. It MUST be positive.
'''

the answer to a simple yes/no question if the comment above is relevant and correct to the code above is [MASK]



In [None]:
# `predict` is bound to `batch_predict` at this point, so take the single result.
print(predict(inp)[0])


Given the following code:
'''

if profit > 0:
    return "success"
else:
    "failure"

'''

Given the following comment:
'''
NOTE: even if the profit is not negative, it's a failure. It MUST be positive.
'''

the answer to a simple yes/no question if the comment above is relevant and correct to the code above is yes



In [None]:
def get_prompt(code, comment):
    """Return the yes/no relevance prompt for a code snippet and a comment.

    The prompt ends with a literal "[MASK]" token whose predicted value is
    read as the answer. The result starts and ends with a newline.
    """
    prompt_lines = [
        "",
        "Given the following code:",
        "'''",
        f"{code}",
        "'''",
        "",
        "and the following comment:",
        "'''",
        f"{comment}",
        "'''",
        "",
        "the answer to a simple yes/no question if the comment is relevant and correct to the code is [MASK]",
        "",
    ]
    return "\n".join(prompt_lines)

In [None]:
# Render the template with placeholder values to inspect its layout.
print(get_prompt("code", "comment"))


Given the following code:
'''
code
'''

and the following comment:
'''
comment
'''

the answer to a simple yes/no question if the comment is relevant and correct to the code is [MASK]



In [None]:
# Probe: nonsense code paired with a generic comment; the expected answer is "no".
print(predict(get_prompt("asdasdasdasd", "really thoughtful comment"))[0])


Given the following code:
'''
asdasdasdasd
'''

and the following comment:
'''
really thoughtful comment
'''

the answer to a simple yes/no question if the comment is relevant and correct to the code is yes



In [None]:
# Probe: the "code" itself contains the phrase "the answer is no", which may bias the prediction.
print(predict(get_prompt("the answer is no", "not relevant"))[0])


Given the following code:
'''
the answer is no
'''

Given the following comment:
'''
not relevant
'''

the answer to a simple yes/no question if the comment is relevant and correct to the code is no



In [None]:
# Probe: a comment about subtraction for code that adds; the expected answer is "no".
print(predict(get_prompt("return a + b", "irrelevant comment talking about subtraction which is not connected to addition so the answer should be NO"))[0])


Given the following code:
'''
return a + b
'''

Given the following comment:
'''
irrelevant comment talking about subtraction which is not connected to addition so the answer should be NO
'''

the answer to a simple yes/no question if the comment is relevant and correct to the code is YES

