OPT-350M Throws Error On Load after Finetuning #17389

Closed
Leli1024 opened this issue May 24, 2022 · 16 comments

@Leli1024

System Info

- `transformers` version: 4.19.0
- Platform: macOS-12.3.1-arm64-i386-64bit
- Python version: 3.8.13
- Huggingface_hub version: 0.2.1
- PyTorch version (GPU?): 1.10.2 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

🐛 Bug

When the OPT-350M variant is fine-tuned via Hugging Face, the resulting model gives the following error when loaded:

model = OPTForCausalLM.from_pretrained(model_path)

RuntimeError: Error(s) in loading state_dict for OPTForCausalLM:
        size mismatch for lm_head.weight: copying a param with shape torch.Size([50272, 512]) from checkpoint, the shape in current model is torch.Size([50272, 1024]).
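
For context, OPT-350M's config has hidden_size=1024 but word_embed_proj_dim=512; the fine-tuned checkpoint stores lm_head.weight at the 512-dim projection while the affected transformers version builds the model expecting 1024, hence the mismatch. A minimal diagnostic sketch (not a fix; the checkpoint path is an assumption, adjust it to wherever save_pretrained() wrote the model):

```python
import torch
from transformers import OPTConfig

# Compare the dims the config declares with what the saved checkpoint contains.
config = OPTConfig.from_pretrained("facebook/opt-350m")
print("hidden_size:", config.hidden_size)                  # 1024
print("word_embed_proj_dim:", config.word_embed_proj_dim)  # 512 (lm_head dimension)

# Assumed path: wherever your fine-tuned model was saved.
state_dict = torch.load("opt/model_ckpts/pytorch_model.bin", map_location="cpu")
lm_head = state_dict.get("lm_head.weight")
if lm_head is not None:
    print("checkpoint lm_head.weight:", tuple(lm_head.shape))  # (50272, 512) here
```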

## Code to load model

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, OPTForCausalLM
import torch

def generate_text(model, tokenizer, prompt):
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, do_sample=True, num_return_sequences=5, max_length=10)
    texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    
    return texts
    
path = "facebook/opt-350m"
path = "opt/model_ckpts"
model = OPTForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)

prompt = "The woman worked as a"

print(generate_text(model, tokenizer, prompt))

## Training Code

import torch as th
from dataset import get_examples, GSMDataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import GPT2Config, AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, OPTModel, OPTConfig, OPTForCausalLM
import torch

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)

try:
    model = OPTForCausalLM.from_pretrained("model_ckpts")
    print("model loaded")
except Exception as e:
    print(e)
train_examples = get_examples("train")
train_dset = GSMDataset(tokenizer, train_examples)

device = th.device("cuda")

model.to(device)
model.train()

train_loader = DataLoader(train_dset, batch_size=4, shuffle=True)
optim = AdamW(model.parameters(), lr=1e-5)

num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optim,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

pbar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    for batch in train_loader:
        optim.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs[0]
        loss.backward()
        optim.step()
        lr_scheduler.step()
        pbar.update(1)
        pbar.set_description(f"train_loss: {loss.item():.5f}")

model.save_pretrained("model_ckpts/")

## Dataset module

import json
import os
import re
import torch as th


def read_jsonl(path: str):
    with open(path) as fh:
        return [json.loads(line) for line in fh.readlines() if line]


def get_examples(split):
    path = os.path.join("data/", f"{split}.jsonl")
    examples = read_jsonl(path)
    
    #examples = examples[0:100]

    for ex in examples:
        ex.update(question=ex["question"] + "\n")
        ex.update(answer=ex["answer"] + "<|endoftext|>")

    print(f"{len(examples)} {split} examples")
    return examples


ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    match = ANS_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    else:
        return INVALID_ANS


def is_correct(model_completion, gt_example):
    gt_answer = extract_answer(gt_example["answer"])
    assert gt_answer != INVALID_ANS
    return extract_answer(model_completion) == gt_answer


class GSMDataset(th.utils.data.Dataset):
    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        self.examples = examples
        self.qns = [ex["question"] for ex in self.examples]
        self.ans = [ex["answer"] for ex in self.examples]
        self.qns = tokenizer(self.qns, padding=False)
        self.ans = tokenizer(self.ans, padding=False)
        self.loss_on_prefix = loss_on_prefix
        self.max_len = max(
            [
                len(self.qns["input_ids"][i]) + len(self.ans["input_ids"][i])
                for i in range(len(self.examples))
            ]
        )
        print(f"Max tokens: {self.max_len}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        qn_tokens = self.qns["input_ids"][idx]
        ans_tokens = self.ans["input_ids"][idx]
        pad_tokens = [0] * (self.max_len - len(qn_tokens) - len(ans_tokens))
        tokens = qn_tokens + ans_tokens + pad_tokens
        mask = (
            ([int(self.loss_on_prefix)] * len(qn_tokens))
            + ([1] * len(ans_tokens))
            + ([0] * len(pad_tokens))
        )
        tokens = th.tensor(tokens)
        mask = th.tensor(mask)
        return dict(input_ids=tokens, attention_mask=mask)

### Expected behavior

Expected model to load.
@omerarshad

Facing the same error, unable to load after fine-tuning. Any update?

@ydshieh
Collaborator

ydshieh commented May 25, 2022

Ping @patrickvonplaten , but also cc @younesbelkada and @ArthurZucker .

@ArthurZucker
Collaborator

On it 👍

@ydshieh
Collaborator

ydshieh commented May 25, 2022

@Leli1024 @omerarshad If you don't mind and have some time, maybe you can try with the latest dev build?

If you clone the repo, you can do it like `pip install --upgrade -e .[dev]`.
(There are some minor fixes since then, I didn't check if they are related)

@younesbelkada
Contributor

younesbelkada commented May 25, 2022

Not sure if it is related, but it is possible that you used a version of transformers from before this PR was merged: #17225

@Leli1024
Author

> @Leli1024 @omerarshad If you don't mind and have some time, maybe you can try with the latest dev build?
>
> If you clone the repo, you can do it like `pip install --upgrade -e .[dev]`. (There are some minor fixes since then, I didn't check if they are related)

This totally worked, thank you!!!
Also, not to be pedantic, but I needed to remove `[dev]` from the command to run it. Just thought I should let anyone else having trouble with it know.

@ydshieh
Collaborator

ydshieh commented May 25, 2022

> @Leli1024 @omerarshad If you don't mind and have some time, maybe you can try with the latest dev build?
> If you clone the repo, you can do it like `pip install --upgrade -e .[dev]`. (There are some minor fixes since then, I didn't check if they are related)
>
> This totally worked, thank you!!!

Great!

@omerarshad

So building from source worked? or is the patch released?

@Leli1024
Author

> So building from source worked? or is the patch released?

Building from source.

@donaghhorgan

donaghhorgan commented Jun 16, 2022

I'm experiencing this issue when I try to use the Inference API to test a facebook/opt-350m model fine-tuned using transformers 4.19.3, 4.19.4, or 4.20.0, and even when I install directly from git like this:

`python -m pip install git+https://github.com/huggingface/transformers`

The error I'm seeing is identical to above:

Error(s) in loading state_dict for OPTForCausalLM: size mismatch for lm_head.weight: copying a param with shape torch.Size([50272, 512]) from checkpoint, the shape in current model is torch.Size([50272, 1024]).

If I download the model to my machine and run it using a pipeline, then it works - it just seems to be an issue for the Inference API.
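
For reference, that local pipeline check looks roughly like this (a sketch; the repo id is the fine-tuned model uploaded further down in this thread, and the base tokenizer is passed explicitly in case the uploaded repo lacks tokenizer files):

```python
from transformers import pipeline

# Load the fine-tuned model locally instead of going through the Inference API.
generator = pipeline(
    "text-generation",
    model="dhorgan/17389",          # fine-tuned repo from this thread; substitute your own
    tokenizer="facebook/opt-350m",  # explicit base tokenizer, in case the repo has none
)
print(generator("The woman worked as a", max_length=10))
```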

Here are the package versions I'm using:

  • Transformers 4.20.0
  • Pytorch 1.11.0+cu102
  • Datasets 2.2.2
  • Tokenizers 0.12.1

@ArthurZucker
Collaborator

Hey, could you provide an example script to help us reproduce the error?

@donaghhorgan

This seems to be able to reproduce it for me:

import pathlib

from datasets import DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    Trainer,
    TrainingArguments,
)

HUGGINGFACE_API_KEY = "..."


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

    training_args = TrainingArguments(
        output_dir="/tmp/model",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        push_to_hub=True,
        hub_strategy="end",
        hub_model_id="17389",
        hub_token=HUGGINGFACE_API_KEY,
    )

    path = pathlib.Path("/tmp/data/dataset.txt")
    path.parent.mkdir(exist_ok=True)
    with path.open("w") as fp:
        for _ in range(10):
            fp.write("Hello, world\n")

    def encode(batch):
        encodings = tokenizer(batch["text"], padding="max_length", truncation=True)
        encodings["labels"] = encodings["input_ids"].copy()
        return encodings

    dataset = DatasetDict.from_text(
        {"train": path.as_posix(), "validation": path.as_posix()}
    ).map(
        encode,
        remove_columns="text",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        data_collator=default_data_collator,
    )
    trainer.train()
    trainer.save_model()

Just ran this on my machine and the resulting model is here: https://huggingface.co/dhorgan/17389

@donaghhorgan

Hi @ArthurZucker, have you had any luck with this? I tried running the example code above again today with v4.20.1 after #17785 was merged, but nothing seems to have changed. The new model is here, if you're interested: https://huggingface.co/dhorgan/17389-test-fix

@ArthurZucker
Collaborator

Hey! Yeah, I know where the bug is from! The Inference API is not up to date with the main branch of transformers! @Narsil is the one handling that, but he is on holiday! Gotta wait for a bit 😀

@Narsil
Contributor

Narsil commented Jul 4, 2022

Hi @donaghhorgan ,

You are not including the tokenizer in your Trainer, so it is not saved with your model: https://huggingface.co/dhorgan/17389-test-fix/tree/main

You can fix this by simply doing `tokenizer.save_pretrained('....')` and uploading it, or by passing `Trainer(tokenizer=tokenizer)` (I think; I don't use Trainer that often personally, but I have seen that being suggested and working).

Anyhow, you can check the failure by doing:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dhorgan/17389-test-fix")

It should crash (because no tokenizer files are there).
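
For anyone following along, a minimal sketch of both suggestions, reusing the names and paths from the reproduction script earlier in the thread (treat them as placeholders):

```python
from transformers import AutoTokenizer

# Option 1: save the tokenizer into the Trainer's output directory yourself,
# so the tokenizer files are uploaded alongside the model weights.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.save_pretrained("/tmp/model")  # same output_dir as in TrainingArguments above

# Option 2: build the Trainer with tokenizer=tokenizer, e.g.
#   trainer = Trainer(model=model, args=training_args,
#                     train_dataset=dataset["train"],
#                     eval_dataset=dataset["validation"],
#                     data_collator=default_data_collator,
#                     tokenizer=tokenizer)
# so that trainer.save_model() and the hub push also include the tokenizer files.
```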

@donaghhorgan

That's great, thanks @Narsil! It's all working for me here now.
