In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "distilgpt2" 
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to("mps")

In [7]:
text = "today is a beautiful day"
feature = tokenizer(text, return_tensors="pt").to("mps")
outputs = model(**feature, labels=feature["input_ids"])
loss = outputs.loss

In [7]:
import torch
import torch.nn as nn

max_length = 20
pad = tokenizer.eos_token
pad_id = tokenizer.encode(pad)

train_input_ids_batch = []
train_attention_mask_batch = []
train_token_type_ids_batch = []

# data = {
#     "facts": ['person8 is a son to person11', 'person4 is a grandfather of person8'],
#     "questions": ['How is person11 related to person4?', 'person4 is the father of person11']
# }

data = {
    "facts": [
        'orange is a fruit', 
        'apple is red',
        'banana is yellow',
        'apple is a fruit',
        'banana is a fruit',
        'orange is not a tree',
        'apple is not spicy',
        'banana is not spicy',],
    "questions": ['orange is fruit?', 'yes']
}

# data["facts"] = [
#     "apple is red",
#     "banana is yellow"
# ]

for i, fact in enumerate(data["facts"]):
    in_txt = f"fact_{i}: "
    print(in_txt)
    out_txt = fact
    
    ids1 = tokenizer(in_txt, return_tensors="pt")["input_ids"]
    ids2 = tokenizer(out_txt, return_tensors="pt")["input_ids"]

    n_mask = max_length - ids1.size(1) - ids2.size(1)
    assert n_mask >= 0, (max_length, ids1.size(1), ids2.size(1))
    padding = torch.LongTensor(pad_id * n_mask).unsqueeze(0)

    input_ids = torch.cat((ids1, ids2, padding), dim=1)
    attention_mask = torch.LongTensor(
        [1] * (ids1.size(1) + ids2.size(1)) + [0] * n_mask).unsqueeze(0)
    token_type_ids = torch.LongTensor(
        [0] * ids1.size(1) + [1] * ids2.size(1) + [0] * n_mask).unsqueeze(0)

    assert input_ids.size(1) == attention_mask.size(
        1) == token_type_ids.size(1) == max_length
    
    train_input_ids_batch.append(input_ids)
    train_attention_mask_batch.append(attention_mask)
    train_token_type_ids_batch.append(token_type_ids)

train_input_ids = torch.cat(train_input_ids_batch, dim=0)
train_attention_mask = torch.cat(train_attention_mask_batch, dim=0)
train_token_type_ids = torch.cat(train_token_type_ids_batch, dim=0)

features = {
    "input_ids": train_input_ids,
    "attention_mask": train_attention_mask,
    "token_type_ids": train_token_type_ids,
    "labels": train_input_ids
}

fact_0: 
fact_1: 
fact_2: 
fact_3: 
fact_4: 
fact_5: 
fact_6: 
fact_7: 


In [8]:
dev_input_ids_batch = []
dev_attention_mask_batch = []
dev_token_type_ids_batch = []

qa = data["questions"]
in_txt = qa[0]
print(in_txt)
out_txt = qa[1]
print(out_txt)

ids1 = tokenizer(in_txt, return_tensors="pt")["input_ids"]
ids2 = tokenizer(out_txt, return_tensors="pt")["input_ids"]

n_mask = max_length - ids1.size(1) - ids2.size(1)
assert n_mask >= 0, (max_length, ids1.size(1), ids2.size(1))
padding = torch.LongTensor(pad_id * n_mask).unsqueeze(0)

input_ids = torch.cat((ids1, ids2, padding), dim=1)
attention_mask = torch.LongTensor(
    [1] * (ids1.size(1) + ids2.size(1)) + [0] * n_mask).unsqueeze(0)
token_type_ids = torch.LongTensor(
    [0] * ids1.size(1) + [1] * ids2.size(1) + [0] * n_mask).unsqueeze(0)

assert input_ids.size(1) == attention_mask.size(
    1) == token_type_ids.size(1) == max_length

dev_input_ids_batch.append(input_ids)
dev_attention_mask_batch.append(attention_mask)
dev_token_type_ids_batch.append(token_type_ids)

dev_input_ids = torch.cat(dev_input_ids_batch, dim=0)
dev_attention_mask = torch.cat(dev_attention_mask_batch, dim=0)
dev_token_type_ids = torch.cat(dev_token_type_ids_batch, dim=0)

dev_features = {
    "input_ids": dev_input_ids,
    "attention_mask": dev_attention_mask,
    "token_type_ids": dev_token_type_ids,
    "labels": dev_input_ids
}

orange is fruit?
yes


In [34]:
from torch.utils.data import DataLoader

def batch_split(row):
    inner_loader = DataLoader(
        row, batch_size=4, shuffle=False)
    splited = [inner_batch for inner_batch in inner_loader]
    return splited

def batch_aggregate(rb):
    inputs, masks, types = rb[0], rb[1], rb[2]
    train_feature = {
        "input_ids": inputs,
        "attention_mask": masks,
        "token_type_ids": types,
        "evaluate": False
    }
    return train_feature

batch = features
rebatch = [batch_split(batch[key]) for key in batch]
inner_train_features = [batch_aggregate(rb) for rb in zip(*rebatch)]

inner_train_features[0]

{'input_ids': tensor([[22584,    62,    15,    25,   220, 43745,   318,   257,  8234, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
         [22584,    62,    16,    25,   220, 18040,   318,  2266, 50256, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
         [22584,    62,    17,    25,   220,  3820,  2271,   318,  7872, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
         [22584,    62,    18,    25,   220, 18040,   318,   257,  8234, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'token_type_ids': tensor([[0, 0, 0, 

In [117]:
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:3")

model_params = []

for idx, (name, param) in enumerate(model.named_parameters()):
    if param.requires_grad:
        model_params.append({"params": param, "lr": 1e-5})

inner_optimizer = torch.optim.SGD(model_params, lr=1e-5)
# optimizer = torch.optim.SGD(model.parameters(), lr=3e-5, momentum=0.9)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
optimizer = inner_optimizer


In [118]:
from tqdm import tqdm 
from meta_kg.optimizer import LSLRSchedular

num_iterations = 5
train_loss = []
train_loop = tqdm(range(num_iterations))

schedular = LSLRSchedular(
    inner_optimizer,
    num_inner_iter=num_iterations,
    init_lr=1e-5
)
schedular.initialization(model.named_parameters())


def compute_loss(features, outputs):
    logits = outputs["logits"][..., :-1, :].contiguous()
    labels = features["input_ids"][..., 1:].contiguous()
    label_mask = features["token_type_ids"][..., 1:].contiguous()

    loss_fct = nn.CrossEntropyLoss(reduction="none")
    losses = loss_fct(
        logits.view(-1, logits.size(-1)),
        labels.view(-1)
    )
    losses = losses.view(logits.size(0), logits.size(1)) * label_mask
    loss = torch.sum(losses, axis=1) / torch.sum(label_mask, axis=1)
    loss = torch.mean(loss)

model.train()
for i in train_loop:
    outputs = model(**features)
    loss = outputs['loss']
    train_loss.append(loss.item())
    train_loop.set_description(f"train_loss: {loss.item():.4f}")
    loss.backward()
    schedular.step(model.named_parameters(), i)
    optimizer.step()
    optimizer.zero_grad()

train_loss: 6.7593: 100%|██████████| 5/5 [00:00<00:00, 20.49it/s]


In [57]:
with torch.no_grad():
    model.eval()
    outputs = model(**features)
    logits = outputs["logits"][..., :-1, :].contiguous()
    label_mask = features["token_type_ids"][..., 1:].contiguous()
    norm_logits = torch.softmax(logits, dim=-1)
norm_logits.size()

torch.Size([2, 19, 50257])

In [73]:
def get_token_prob(prob_locs, norm_logits, input_ids, k=10):
    topk_tokens_seq = []
    for loc in prob_locs:
        token = tokenizer.decode(input_ids[loc])
        topk_prob = torch.topk(norm_logits[loc], k=k)
        topk_tokens = []
        for idx in topk_prob.indices:
            gen_token = tokenizer.decode(idx.unsqueeze(0))
            prob = round(norm_logits[loc][idx].item() * 100, 3)
            topk_tokens.append((token.strip(), gen_token.strip(), prob))
        topk_tokens_seq.append(topk_tokens)
    return topk_tokens_seq

batch_topk_tokens = []
for i in range(norm_logits.size(0)):
    input_ids = features["input_ids"][i][1:]
    norm_logits_item = norm_logits[i]* label_mask[i].unsqueeze(-1)
    prob_locs = torch.nonzero(norm_logits_item)[:, 0].unique()
    topk_tokens = get_token_prob(prob_locs, norm_logits_item, input_ids)
    batch_topk_tokens.append(topk_tokens)


person
person


In [5]:
import higher
from tqdm import tqdm 

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:2")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
num_iterations = 5
loss_fct = nn.CrossEntropyLoss(reduction="none")

model.train()
train_loop = tqdm(range(num_iterations))
for i in range(num_iterations):
    with higher.innerloop_ctx(
        model, optimizer,
        copy_initial_weights=False,
        track_higher_grads=True) as (fmodel, diffopt):

        for j in range(1):
            outputs = fmodel(**features)
            inner_loss = outputs['loss']
            # logits = outputs["logits"][..., :-1, :].contiguous()
            # labels = features["input_ids"][..., 1:].contiguous()
            # label_mask = features["token_type_ids"][..., 1:].contiguous()

            # inner_losses = loss_fct(
            #     logits.view(-1, logits.size(-1)),
            #     labels.view(-1)
            # )
            # inner_losses = inner_losses.view(logits.size(0), logits.size(1)) * label_mask
            # inner_loss = torch.sum(inner_losses, axis=1) / torch.sum(label_mask, axis=1)
            # inner_loss = torch.mean(inner_loss)
            diffopt.step(inner_loss)
            print(f"inner_loss: {inner_loss.item():.4f}")

        outputs = fmodel(**dev_features)
        logits = outputs["logits"][..., :-1, :].contiguous()
        labels = dev_features["input_ids"][..., 1:].contiguous()
        label_mask = dev_features["token_type_ids"][..., 1:].contiguous()

        losses = loss_fct(
            logits.view(-1, logits.size(-1)),
            labels.view(-1)
        )
        losses = losses.view(logits.size(0), logits.size(1)) * label_mask
        loss = torch.sum(losses, axis=1) / torch.sum(label_mask, axis=1)
        loss = torch.mean(loss)
        print(f"loss: {loss.item():.4f}")

        # loss.backward()
        # optimizer.step()
        # optimizer.zero_grad()



  0%|          | 0/5 [00:17<?, ?it/s]


inner_loss: 6.3024
loss: 10.7067
inner_loss: 6.2737
loss: 11.3563
inner_loss: 6.4292
loss: 10.5923
inner_loss: 6.2354
loss: 10.4774
inner_loss: 6.2527
loss: 10.6698
