In [None]:
import transformers
import torch
import numpy as np
import gc
import tempfile

from transformers import (LlamaForCausalLM, 
                          LlamaTokenizer,
                          AutoTokenizer, 
                          AutoModelForCausalLM)

from typing import List, Dict, Any


### 加载模型

In [None]:
# Load the tokenizer and the causal-LM weights from a local checkpoint directory.
model_path = '/workspace/acl/model_zoo/llama/llama-2-7b-chat-hf'
# NOTE(review): `device_map` is passed to the tokenizer here but not to the model
# call below — presumably it was meant for the model; confirm before changing.
tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=False, device_map = "auto")
# NOTE(review): trust_remote_code is False for the tokenizer but True for the
# model — confirm this asymmetry is intended.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

### 加载模型参数

In [None]:
def load_model(model_name_or_path, trust_remote_code:bool=True, device_map="auto"):
    """Load a causal language model checkpoint.

    Args:
        model_name_or_path: Checkpoint identifier or local directory, forwarded
            unchanged to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        device_map: Forwarded to ``from_pretrained``.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    # The Auto* classes are factories: they cannot be instantiated directly and
    # must be built through `from_pretrained` (or `from_config`).
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=trust_remote_code,
        device_map=device_map,
    )
    return model

def get_model_param_list(model_names: List[str], model_type:str):
    """Load each checkpoint in ``model_names`` and collect its ``state_dict``.

    Args:
        model_names: Checkpoint names or paths, each passed to ``load_model``.
        model_type: Accepted so existing callers keep working; it is not used,
            because ``load_model`` has no such parameter.

    Returns:
        A list with one ``state_dict()`` result per entry of ``model_names``,
        in the same order.
    """
    model_param_list = []
    for name in model_names:
        print(f"loading {name} -----------------")
        # `load_model` takes no `model_type` keyword; forwarding it raised TypeError.
        model = load_model(name)
        model_param_list.append(model.state_dict())
    return model_param_list


### 模型融合

In [None]:
def merge_param(model_param_list: List[Dict], weights: List[float]):
    """Build a weighted sum of several parameter dictionaries.

    For every key of the first dictionary, floating-point (or complex) tensors
    are combined as ``sum(w_i * param_i[key])``. Tensors of any other dtype
    (integer, bool, ...) cannot be averaged meaningfully, so the value from the
    last dictionary is copied instead.

    Args:
        model_param_list: One mapping of name -> tensor per model. Every mapping
            must contain all keys of the first one.
        weights: One scalar weight per entry of ``model_param_list``.

    Returns:
        A new dict with the same keys as ``model_param_list[0]``. The input
        tensors are not modified.

    Raises:
        ValueError: If ``weights`` and ``model_param_list`` differ in length.
        IndexError: If ``model_param_list`` is empty.
    """
    # zip() would silently drop the surplus models or weights; fail loudly instead.
    if len(model_param_list) != len(weights):
        raise ValueError(
            f"expected one weight per model, got {len(weights)} weights "
            f"for {len(model_param_list)} models"
        )

    new_param = {}
    for k in model_param_list[0].keys():
        tensors = [param[k] for param in model_param_list]
        reference = tensors[0]
        if not (reference.is_floating_point() or reference.is_complex()):
            # Non-float buffers are copied, not scaled, so their dtype is preserved.
            new_param[k] = tensors[-1].clone()
            continue
        # The first product allocates a fresh tensor, so the in-place
        # accumulation below never touches the inputs.
        merged = weights[0] * reference
        for w, tensor in zip(weights[1:], tensors[1:]):
            merged += w * tensor
        new_param[k] = merged
    return new_param

In [None]:
def test(**kwargs):
    """Print the value passed as keyword ``a`` and report that it was found.

    Returns ``True`` when ``a`` is among the keyword arguments; otherwise
    nothing is printed and ``None`` is returned.
    """
    try:
        value = kwargs['a']
    except KeyError:
        return None
    print(value)
    return True

test(a=2)

In [None]:
import copy
# print(model.state_dict())
# Work on a deep copy so edits to `a` can be compared against the untouched `model`.
a = copy.deepcopy(model)
# Only the first state_dict entry is exercised; the loop exits once idx exceeds 0.
for idx, k in enumerate(a.state_dict().keys()):
    if idx > 0:
        break
    # Element [0, 0] of the copy's entry before the in-place doubling.
    print(k, a.state_dict()[k][0,0])
    # NOTE(review): this relies on state_dict() handing back tensors that share
    # storage with the module's parameters — the next print is the check for that.
    a.state_dict()[k] *= 2
    print(k, a.state_dict()[k][0,0])
    # Same element in the original model, before and after loading the copy's state.
    print(k, model.state_dict()[k][0,0])
    model.load_state_dict(a.state_dict())
    print(k, model.state_dict()[k][0,0])
    


### 在python中调用bash命令

In [None]:
import subprocess
import os
# Two ways of launching wbw_test.py from Python; only the os.system call is active.
# output = subprocess.run(['python', 'wbw_test.py'])
#print(output.decode())
# os.system returns the command's exit status, not its output.
output2 = os.system('python wbw_test.py')
#print(output2)

In [None]:
import os
# Opens test.txt in write mode, which creates the file or truncates an existing one.
with open('test.txt', mode='w', encoding='utf-8') as f:
    # NOTE(review): this prints to stdout and nothing is written to `f` —
    # confirm whether `print(1, file=f)` was intended.
    print(1)

In [None]:
import sys
# `sys` has no `_path__` attribute (AttributeError); the module search path is `sys.path`.
print(sys.path)

## 加载CITB数据集

In [2]:
from datasets import *
from ds import *
import json

ModuleNotFoundError: No module named 'tabulate'

In [10]:
# Parse one task JSON file from the CITB tasks directory into `j`.
json_path = '/workspace/acl/ds/CITB/data/tasks/task001_quoref_question_generation.json'
with open(json_path, 'r', encoding='utf-8') as f:
    
    j = json.load(f)



In [None]:
# Build the dataset through the local loading script, capping the number of
# instances per training task and per evaluation task.
raw_datasets = load_dataset(
        # "src/ni_dataset.py", 
        "/workspace/acl/ds/CITB/continual_learning/ni_dataset_for_cl.py", # use modified dataset script
        data_dir='/workspace/acl/ds/CITB/data/CIT_data/initial_multitask_learning', 
        task_dir='/workspace/acl/ds/CITB/data/tasks/', 
        cache_dir='./cache',
        max_num_instances_per_task=5000,
        max_num_instances_per_eval_task=50,
        task_split_file_name='train_tasks',
        load_official_test=False    # instead we load the official test set below
    )

In [None]:
# Number of entries under the 'Positive Examples' key of the loaded task file.
len(j['Positive Examples'])

In [14]:
# Reload a dataset that was previously saved to disk at this directory.
path = '/workspace/acl/ds/CITB/data/CIT_data/initial_multitask_learning/defintion_pos_2/train'
dataset = load_from_disk(path)

In [None]:
# Display the dataset object's summary representation.
dataset

In [None]:
# Number of values in the 'Instance' column.
len(dataset['Instance'])

In [None]:
# dataset['Task']
# Print only the first value of the 'Task' column instead of the whole column.
for t in dataset['Task']:
    print(t)
    break

In [None]:
# Display the entire 'Categories' column (one value per row, so output can be long).
dataset['Categories']

In [None]:
# Collect the distinct values of the 'Categories' column, keeping first-seen order.
cat = []
for c in dataset['Categories']:
    if c not in cat:
        cat.append(c)

cat



In [None]:
# Display the entire 'Domains' column (one value per row, so output can be long).
dataset['Domains']


In [None]:
# Collect the distinct values of the 'Domains' column, keeping first-seen order.
domains = []
for d in dataset['Domains']:
    if d not in domains:
        domains.append(d)

domains

In [None]:
# 'Positive Examples' field of the first row.
dataset[0]['Positive Examples']


In [None]:
# 'Negative Examples' field of the first row.
dataset[0]['Negative Examples']

## 加载SuperGLUE数据集

In [59]:
import datasets
import ds

In [64]:
# Same loading script as earlier in the notebook, but with data_dir pointing at
# the 'defintion_pos_2' sub-directory.
data = load_dataset(
        # "src/ni_dataset.py", 
        "/workspace/acl/ds/CITB/continual_learning/ni_dataset_for_cl.py", # use modified dataset script
        data_dir='/workspace/acl/ds/CITB/data/CIT_data/initial_multitask_learning/defintion_pos_2', 
        task_dir='/workspace/acl/ds/CITB/data/tasks/', 
        cache_dir='./cache/',
        max_num_instances_per_task=5000,
        max_num_instances_per_eval_task=50,
        task_split_file_name='train_tasks',
        load_official_test=False    # instead we load the official test set below
    )

In [None]:
# Split `data` into train/dev/test instances with the project helper from `ds`,
# using the same per-task caps that were passed to load_dataset.
# NOTE(review): the meaning of `continual=True` is defined in `ds`, not here — confirm.
train_instances, dev_instances, test_instances = ds.train_dev_test_split_by_task(data,
        max_num_instances_per_task=5000,
        max_num_instances_per_eval_task=50,
        continual=True
    )


In [None]:
# Tally how many training instances belong to each task name.
state = {}
for d in data['train']:
    state[d['Task']] = state.get(d['Task'], 0) + 1
state

In [53]:
# Print the task name of every test instance whose task was also counted in `state`.
for d in data['test']:
    if d['Task'] not in state:
        continue
    print(d['Task'])
