In [1]:
import mindspore
# 

In [2]:
from mindnlp.dataset import load_dataset
from mindnlp.transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from mindnlp.engine import set_seed
from mindnlp.peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit

# Seed before any model or prompt parameters are created.
set_seed(42)

model_name = "google/flan-t5-base"

# The tokenizer is reused by the dataset map functions and the evaluation helpers below.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Source-stage adapter: two tasks (the map functions below tag SST-2 rows with
# task id 0 and MNLI rows with task id 1), 50 virtual tokens initialised from text.
peft_config = MultitaskPromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    tokenizer_name_or_path=model_name,
    num_tasks=2,
    num_virtual_tokens=50,
    num_transformer_submodules=1,
    prompt_tuning_init=MultitaskPromptTuningInit.TEXT,
    prompt_tuning_init_text=(
        "classify the following into either positive or negative, "
        "or entailment, neutral or contradiction:"
    ),
)

# Wrap the pretrained seq2seq backbone with the prompt-tuning adapter.
model = get_peft_model(AutoModelForSeq2SeqLM.from_pretrained(model_name), peft_config)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.596 seconds.
Prefix dict has been built successfully.
The following parameters in models are missing parameter:
['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']


In [4]:
# SST-2 dataset; indexed by split name ('train', 'validation') in the helpers below.
sst_dataset = load_dataset("sst2")

In [5]:
# MultiNLI dataset; only its 'train' split is used in this notebook.
multi_nli_dataset = load_dataset("multi_nli")

In [6]:
import numpy as np
from mindnlp.dataset import BaseMapFunction

class SST2Map(BaseMapFunction):
    """Turn one SST-2 row into tokenised seq2seq features tagged with task id 0."""

    def __call__(self, idx, sentence, label):
        """Return ``(input_ids, attention_mask, labels, task_ids)`` for one row.

        ``idx`` is accepted because it is one of the mapped input columns but is
        not used. Label 1 maps to the target word "positive", anything else to
        "negative"; both are terminated with the tokenizer's EOS token.
        """
        source_text = str(sentence).strip() + "</s>"
        if label == 1:
            target_text = f"positive{tokenizer.eos_token}"
        else:
            target_text = f"negative{tokenizer.eos_token}"

        # Special tokens are disabled because the end marker is appended by hand above.
        encoded = tokenizer(source_text, add_special_tokens=False)
        target_ids = tokenizer(target_text, add_special_tokens=False).input_ids
        # Any pad id inside the target is replaced by -100.
        is_pad = np.equal(target_ids, tokenizer.pad_token_id)
        target_ids = np.where(is_pad, -100, target_ids)
        return encoded.input_ids, encoded.attention_mask, target_ids, 0

class MNLIMap(BaseMapFunction):
    """Turn one MultiNLI row into tokenised seq2seq features tagged with task id 1."""

    def __call__(self, promptID, pairID, premise, premise_binary_parse, premise_parse, hypothesis,
                 hypothesis_binary_parse, hypothesis_parse, genre, label):
        """Return ``(input_ids, attention_mask, labels, task_ids)`` for one row.

        Only ``premise``, ``hypothesis`` and ``label`` are read; the remaining
        parameters exist because they are mapped input columns. Label 0 maps to
        "entailment", 1 to "neutral" and every other value to "contradiction".
        """
        source_text = str(premise).strip() + " " + str(hypothesis).strip() + "</s>"

        target_text = f"contradiction{tokenizer.eos_token}"
        if label == 0:
            target_text = f"entailment{tokenizer.eos_token}"
        elif label == 1:
            target_text = f"neutral{tokenizer.eos_token}"

        # Special tokens are disabled because the end marker is appended by hand above.
        encoded = tokenizer(source_text, add_special_tokens=False)
        target_ids = tokenizer(target_text, add_special_tokens=False).input_ids
        # Any pad id inside the target is replaced by -100.
        is_pad = np.equal(target_ids, tokenizer.pad_token_id)
        target_ids = np.where(is_pad, -100, target_ids)
        return encoded.input_ids, encoded.attention_mask, target_ids, 1

In [7]:
def get_combine_dataset(shuffle=False, batch_size=8):
    """Build the batched SST-2 + MNLI training set used for source training.

    Args:
        shuffle: when true, shuffle the result with a buffer size of 1024.
            Shuffling is applied after batching, so whole batches are reordered.
        batch_size: number of examples per padded batch.

    Returns:
        A batched dataset with the columns ``input_ids``, ``attention_mask``,
        ``labels`` and ``task_ids``.
    """
    sst_columns = ['idx', 'sentence', 'label']
    mnli_columns = ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis',
                    'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label']
    output_columns = ['input_ids', 'attention_mask', 'labels', 'task_ids']

    sst_train = sst_dataset['train'].map(SST2Map(sst_columns, output_columns), sst_columns, output_columns)
    mnli_train = multi_nli_dataset['train'].map(MNLIMap(mnli_columns, output_columns), mnli_columns, output_columns)

    # SST-2 rows come first, followed by the MNLI rows.
    train_dataset = sst_train + mnli_train
    train_dataset = train_dataset.padded_batch(
        batch_size,
        pad_info={
            'input_ids': (None, tokenizer.pad_token_id),
            'attention_mask': (None, 0),
            # Targets differ in length ("positive" vs. "contradiction"), so labels do
            # get padded. Pad with -100, the same marker the map functions use for
            # padding, instead of the pad token id, so that padded target positions
            # are not scored by the loss (assumes the usual -100 ignore index of the
            # seq2seq loss).
            'labels': (None, -100),
        },
    )
    if shuffle:
        train_dataset = train_dataset.shuffle(1024)

    return train_dataset

In [8]:
def get_sst_dataset(mode, shuffle=False, batch_size=8):
    """Build a batched SST-2 dataset for one split.

    Args:
        mode: split name used to index ``sst_dataset`` (the notebook uses
            'train' and 'validation').
        shuffle: when true, shuffle individual examples with a buffer size of 64
            before they are batched.
        batch_size: number of examples per padded batch.

    Returns:
        A batched dataset with the columns ``input_ids``, ``attention_mask``,
        ``labels`` and ``task_ids``.
    """
    sst_columns = ['idx', 'sentence', 'label']
    output_columns = ['input_ids', 'attention_mask', 'labels', 'task_ids']

    sst_data = sst_dataset[mode].map(SST2Map(sst_columns, output_columns), sst_columns, output_columns)
    if shuffle:
        sst_data = sst_data.shuffle(64)

    return sst_data.padded_batch(
        batch_size,
        pad_info={
            'input_ids': (None, tokenizer.pad_token_id),
            'attention_mask': (None, 0),
            # Pad labels with -100, the same marker SST2Map uses for padding, rather
            # than the pad token id, so padded target positions are never scored by
            # the loss (assumes the usual -100 ignore index of the seq2seq loss).
            'labels': (None, -100),
        },
    )

In [9]:
# Source stage: train on the shuffled SST-2 + MNLI mix, validate on the SST-2 'validation' split.
train_dataset = get_combine_dataset(shuffle=True, batch_size=8)
eval_dataset = get_sst_dataset('validation', batch_size=8)

In [10]:
# Pull a single batch to inspect the column names, shapes and dtypes.
next(train_dataset.create_dict_iterator())

{'input_ids': Tensor(shape=[8, 26], dtype=Int64, value=
 [[  151,    12,  4068 ...     0,     0,     0],
  [   19,    38,     3 ...    40,   814,     1],
  [   34,    19,    46 ...     1,     0,     0],
  ...
  [12430,   920,     1 ...     0,     0,     0],
  [ 4657,    95,    66 ...     0,     0,     0],
  [   81,   985,    13 ...     0,     0,     0]]),
 'attention_mask': Tensor(shape=[8, 26], dtype=Int64, value=
 [[1, 1, 1 ... 0, 0, 0],
  [1, 1, 1 ... 1, 1, 1],
  [1, 1, 1 ... 1, 0, 0],
  ...
  [1, 1, 1 ... 0, 0, 0],
  [1, 1, 1 ... 0, 0, 0],
  [1, 1, 1 ... 0, 0, 0]]),
 'labels': Tensor(shape=[8, 2], dtype=Int64, value=
 [[2841,    1],
  [2841,    1],
  [1465,    1],
  ...
  [2841,    1],
  [1465,    1],
  [1465,    1]]),
 'task_ids': Tensor(shape=[8], dtype=Int64, value= [0, 0, 0, 0, 0, 0, 0, 0])}

## source training

In [11]:
from mindspore.experimental.optim.adamw import AdamW
from mindnlp.modules.optimization import get_cosine_schedule_with_warmup
from tqdm import tqdm
from sklearn.metrics import f1_score

In [12]:
# First token id the tokenizer produces for each sentiment word; classify() compares
# the logits of these two ids and evaluate() uses POSITIVE_TOKEN_ID as the F1 positive label.
POSITIVE_TOKEN_ID = tokenizer(" positive", add_special_tokens=False)["input_ids"][0]
NEGATIVE_TOKEN_ID = tokenizer(" negative", add_special_tokens=False)["input_ids"][0]


def classify(batch, scores=None):
    """Predict a sentiment token id for every example in ``batch``.

    Args:
        batch: dict of model inputs. Labels are included because prediction uses
            a plain forward pass rather than generation (the original author
            notes that PEFT generation was not supported).
        scores: optional logits already computed for ``batch``. When omitted,
            the module-level ``model`` is run on ``batch`` to obtain them.

    Returns:
        A list with one entry per example: ``POSITIVE_TOKEN_ID`` when the first
        decoding position scores the positive token strictly higher than the
        negative one, otherwise ``NEGATIVE_TOKEN_ID``.
    """
    if scores is None:
        scores = model(**batch).logits

    # Only the first decoder position matters: it holds the sentiment word.
    first_position = scores[:, 0]
    positive_wins = (first_position[:, POSITIVE_TOKEN_ID] > first_position[:, NEGATIVE_TOKEN_ID]).tolist()
    return [POSITIVE_TOKEN_ID if wins else NEGATIVE_TOKEN_ID for wins in positive_wins]


def evaluate(model, data):
    """Compute the mean loss and the F1 score of ``model`` on ``data``.

    Args:
        model: the model to evaluate; it is switched to inference mode.
        data: a batched dataset exposing ``get_dataset_size()`` and
            ``create_dict_iterator()``.

    Returns:
        ``(mean_loss, f1)`` where ``mean_loss`` is the summed batch loss divided
        by the number of batches and ``f1`` treats ``POSITIVE_TOKEN_ID`` as the
        positive class.
    """
    model.set_train(False)
    loss = 0
    preds = []
    golds = []

    total = data.get_dataset_size()
    for batch in tqdm(data.create_dict_iterator(), total=total):
        # One forward pass yields both the loss and the logits. The logits are
        # passed on so that predictions come from the model given to this
        # function, not from whatever the global ``model`` happens to be.
        with mindspore._no_grad():
            outputs = model(**batch)
        loss += outputs.loss
        # The first label token is the gold sentiment word.
        golds.extend(batch["labels"][:, 0].tolist())
        preds.extend(classify(batch, scores=outputs.logits))

    return loss / total, f1_score(golds, preds, pos_label=POSITIVE_TOKEN_ID)


In [13]:
# AdamW over the model's trainable parameters with a learning rate of 1e-4.
optimizer = AdamW(model.trainable_params(), lr=1e-4)
# Cosine schedule with warmup; presumably 200 warmup steps and len(train_dataset)
# total steps — confirm against the helper's signature.
scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train_dataset))

In [None]:
n = 1000  # validate and write a checkpoint every n training steps
step = 0

# Baseline metrics before any source-stage update.
val_loss, f1 = evaluate(model, eval_dataset)
print(f"before source training, val loss = {val_loss}, f1 = {f1}")


def forward_fn(**batch):
    """Return the model's loss for one batch; this is the function that gets differentiated."""
    return model(**batch).loss


# Loss and gradients with respect to model.trainable_params().
grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())


def train_step(**batch):
    """Compute loss and gradients for one batch, apply the optimizer, and return the loss."""
    loss, grads = grad_fn(**batch)
    optimizer(grads)
    return loss


train_total = train_dataset.get_dataset_size()
train_ = tqdm(train_dataset.create_dict_iterator(), total=train_total)

for batch in train_:
    # Step 0 is skipped here because the baseline was already reported above.
    if step != 0 and step % n == 0:
        val_loss, f1 = evaluate(model, eval_dataset)
        print(f"step = {step}, val loss = {val_loss}, f1 = {f1}")
        model.save_pretrained(f"checkpoints_source/{step}")

    # evaluate() switches the model to inference mode, so training mode is re-enabled every step.
    model.set_train()
    step += 1
    loss = train_step(**batch)
    scheduler.step()
    train_.set_postfix(train_loss=loss)

100%|██████████████████████████████████████████████████████████████████████████████████| 109/109 [00:29<00:00,  3.72it/s]


before source training, val loss = 14.657343, f1 = 0.31596091205211724


  0%|                                                          | 35/57507 [00:26<9:02:49,  1.76it/s, train_loss=9.211916]

## target training

In [None]:
# Target stage: train on shuffled SST-2 'train' only, validate on SST-2 'validation'.
# Note that this rebinds train_dataset and eval_dataset from the source stage.
train_dataset = get_sst_dataset('train', shuffle=True, batch_size=8)
eval_dataset = get_sst_dataset('validation', batch_size=8)

#### create a fresh model

In [None]:
# Target-stage adapter: a single task whose prompt is initialised from the
# source-stage checkpoint written at step 50000 of the source training loop.
peft_config = MultitaskPromptTuningConfig(
    tokenizer_name_or_path=model_name,
    num_tasks=1,
    task_type=TaskType.SEQ_2_SEQ_LM,
    prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
    # This notebook reloads a save_pretrained() output as "adapter_model.ckpt" in
    # its final cell, so the same file name is used here instead of
    # "adapter_model.bin".
    # NOTE(review): confirm the file name written by the installed mindnlp version.
    prompt_tuning_init_state_dict_path="checkpoints_source/50000/adapter_model.ckpt",
    num_virtual_tokens=50,
    num_transformer_submodules=1,
)

# Start from a fresh pretrained backbone so nothing but the prompt carries over.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = get_peft_model(model, peft_config)


In [None]:
# Fresh AdamW for the new model's trainable parameters, learning rate 1e-4.
optimizer = AdamW(model.trainable_params(), lr=1e-4)
# Cosine schedule with warmup; presumably 200 warmup steps and len(train_dataset)
# total steps — confirm against the helper's signature.
scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train_dataset))

In [None]:
n = 1000  # validate and write a checkpoint every n training steps
step = 0

# Baseline metrics before any target-stage update.
val_loss, f1 = evaluate(model, eval_dataset)
print(f"""before target training, val loss = {val_loss}, f1 = {f1}""")

# training and evaluation
def forward_fn(**batch):
    """Return the model's loss for one batch; this is the function that gets differentiated."""
    outputs = model(**batch)
    loss = outputs.loss
    return loss

# Loss and gradients with respect to model.trainable_params().
grad_fn = mindspore.value_and_grad(forward_fn, None, model.trainable_params())

def train_step(**batch):
    """Compute loss and gradients for one batch, apply the optimizer, and return the loss."""
    loss, grads = grad_fn(**batch)
    optimizer(grads)
    return loss


train_total = train_dataset.get_dataset_size()
train_ = tqdm(train_dataset.create_dict_iterator(), total=train_total)

for batch in train_:
    if step % n == 0 and step != 0:
        val_loss, f1 = evaluate(model, eval_dataset)
        print(f"""step = {step}, val loss = {val_loss}, f1 = {f1}""")
        # Target-stage checkpoints get their own directory: writing them to
        # checkpoints_source/ would overwrite the source-stage checkpoints, and the
        # final cell of this notebook reloads from checkpoints_target/.
        model.save_pretrained(f"checkpoints_target/{step}")

    # evaluate() switches the model to inference mode, so training mode is re-enabled every step.
    model.set_train()
    step += 1
    loss = train_step(**batch)
    scheduler.step()
    train_.set_postfix(train_loss=loss)

In [None]:
# For now, reload the checkpoint stored under checkpoints_target/6000 and apply it to the model.
from mindnlp.peft import set_peft_model_state_dict

sd_6000 = mindspore.load_checkpoint("checkpoints_target/6000/adapter_model.ckpt")
set_peft_model_state_dict(model, sd_6000)

# Validation metrics for the reloaded adapter weights.
val_loss, f1 = evaluate(model, eval_dataset)
print(f"\nfinal\nval loss = {val_loss}\nf1 = {f1}")