### 10/31, torchのmodel.parameter()をうまいこと平均できないか? --

In [1]:
import torch
import torch.nn as nn

In [2]:
m1 = nn.Linear(2, 1)
m2 = nn.Linear(2, 1)

In [3]:
m1.state_dict()

OrderedDict([('weight', tensor([[0.6651, 0.2616]])),
             ('bias', tensor([0.2608]))])

In [4]:
m2.state_dict()

OrderedDict([('weight', tensor([[ 0.3877, -0.4895]])),
             ('bias', tensor([0.5007]))])

In [5]:
m3 = nn.Linear(2, 1)
nn.init.zeros_(list(m3.parameters())[0])
nn.init.zeros_(list(m3.parameters())[1])

Parameter containing:
tensor([0.], requires_grad=True)

In [6]:
m3.state_dict()

OrderedDict([('weight', tensor([[0., 0.]])), ('bias', tensor([0.]))])

### これは無理

In [7]:
for param_id, param in enumerate(list(m3.parameters())):
    list(m3.parameters())[param_id] = list(m1.parameters())[param_id] + list(m2.parameters())[param_id]

In [8]:
m3.state_dict()

OrderedDict([('weight', tensor([[0., 0.]])), ('bias', tensor([0.]))])

### これはいける

In [9]:
def state_dict_divider(state_dict, div=2):
    """state_dictの値を割り算する関数"""
    for key in state_dict.keys():
        state_dict[key] /= float(div)

    return state_dict

In [10]:
avg_state_dict = m1.state_dict()
avg_state_dict

OrderedDict([('weight', tensor([[0.6651, 0.2616]])),
             ('bias', tensor([0.2608]))])

In [11]:
for key in avg_state_dict.keys():
    avg_state_dict[key] += m2.state_dict()[key]

In [12]:
avg_state_dict = state_dict_divider(avg_state_dict, div=2)

In [13]:
avg_state_dict

OrderedDict([('weight', tensor([[ 0.5264, -0.1140]])),
             ('bias', tensor([0.3807]))])

In [14]:
m3.load_state_dict(avg_state_dict)

<All keys matched successfully>

In [15]:
m3.state_dict()

OrderedDict([('weight', tensor([[ 0.5264, -0.1140]])),
             ('bias', tensor([0.3807]))])

### BERT部分でmean取っていいやつとそうでないやつがありそう
* position_idsは明らかにintが入ってる

### folds-meanを結果meanじゃなくてパラメータmeanで取れば安くなる

In [16]:
from glob import glob

In [17]:
model_paths = glob("/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/*.pth")

In [18]:
model_paths.sort()

In [19]:
model_paths

['/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold0.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold1.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold2.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold3.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold4.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-mean_of_folds.pth']

In [20]:
# from tqdm import tqdm
# 
# for i, path in enumerate(model_paths):
#     if i == 0:
#         avg_state_dict = torch.load(path)["model_state_dict"]
#         print(f"... load {path} to set average-root ... ... ")
#     else:
#         print(f"... ... key add : {path}")
#         for key in tqdm(avg_state_dict.keys(), total=len(avg_state_dict.keys())):
#             if key.split(".")[-1] in ["weight", "bias"]:
#                 state_dict = torch.load(path)["model_state_dict"]
#                 avg_state_dict[key] += state_dict[key]
# 

In [21]:
def state_dict_divider(state_dict, div=2):
    """state_dictの値を割り算する関数"""
    for key in state_dict.keys():
        if key.split(".")[-1] in ["weight", "bias"]:
            state_dict[key] /= float(div)

    return state_dict

avg_state_dict = state_dict_divider(avg_state_dict, div=len(model_paths))

In [22]:
# torch.save(
#     {
#         "model_state_dict": avg_state_dict,
#     },
#     f"/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-mean_of_folds.pth",
# )

### define

In [27]:
from tqdm import tqdm

def state_dict_divider(state_dict, div=2, target=["weight", "bias"]):
    """state_dictの値を割り算する関数"""
    for key in state_dict.keys():
        if key.split(".")[-1] in target:
            state_dict[key] /= float(div)

    return state_dict

def model_parameter_ensembler(model_paths: list, target=["weight", "bias"]):
    """.pthのリストを受け取り、その平均を返す"""

    for i, path in enumerate(model_paths):
        if i == 0:
            avg_state_dict = torch.load(path)["model_state_dict"]
            print(f"... load {path} to set average-root ... ... ")
        else:
            print(f"... ... key add : {path}")
            for key in tqdm(avg_state_dict.keys(), total=len(avg_state_dict.keys())):
                if key.split(".")[-1] in target:
                    state_dict = torch.load(path)["model_state_dict"]
                    avg_state_dict[key] += state_dict[key]
    avg_state_dict = state_dict_divider(avg_state_dict, div=len(model_paths), target=target)
    return avg_state_dict

In [28]:
avg_state_dict = model_parameter_ensembler(model_paths[0:5])

... load /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold0.pth to set average-root ... ... 
... ... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold1.pth


100%|██████████| 394/394 [02:05<00:00,  3.15it/s]


... ... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold2.pth


100%|██████████| 394/394 [02:04<00:00,  3.16it/s]


... ... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold3.pth


100%|██████████| 394/394 [02:03<00:00,  3.19it/s]


... ... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold4.pth


100%|██████████| 394/394 [02:05<00:00,  3.14it/s]


In [32]:
torch.sum(torch.abs(avg_state_dict["l2.fc.weight"]))

tensor(26.9288, device='cuda:0')

In [None]:
torch.sum(torch.abs(avg_state_dict["l2.fc.weight"]))