### 10/31, torchのmodel.parameter()をうまいこと平均できないか? --

In [126]:
import torch
import torch.nn as nn

In [127]:
m1 = nn.Linear(2, 1)
m2 = nn.Linear(2, 1)

In [128]:
m1.state_dict()

OrderedDict([('weight', tensor([[ 0.2795, -0.5839]])),
             ('bias', tensor([0.4437]))])

In [129]:
m2.state_dict()

OrderedDict([('weight', tensor([[0.0093, 0.0881]])),
             ('bias', tensor([0.3333]))])

In [130]:
m3 = nn.Linear(2, 1)
nn.init.zeros_(list(m3.parameters())[0])
nn.init.zeros_(list(m3.parameters())[1])

Parameter containing:
tensor([0.], requires_grad=True)

In [131]:
m3.state_dict()

OrderedDict([('weight', tensor([[0., 0.]])), ('bias', tensor([0.]))])

### これは無理

In [132]:
for param_id, param in enumerate(list(m3.parameters())):
    list(m3.parameters())[param_id] = list(m1.parameters())[param_id] + list(m2.parameters())[param_id]

In [133]:
m3.state_dict()

OrderedDict([('weight', tensor([[0., 0.]])), ('bias', tensor([0.]))])

### これはいける

In [134]:
def state_dict_divider(state_dict, div=2):
    """state_dictの値を割り算する関数"""
    for key in state_dict.keys():
        state_dict[key] /= float(div)

    return state_dict

In [135]:
avg_state_dict = m1.state_dict()
avg_state_dict

OrderedDict([('weight', tensor([[ 0.2795, -0.5839]])),
             ('bias', tensor([0.4437]))])

In [136]:
for key in avg_state_dict.keys():
    avg_state_dict[key] += m2.state_dict()[key]

In [137]:
avg_state_dict = state_dict_divider(avg_state_dict, div=2)

In [138]:
avg_state_dict

OrderedDict([('weight', tensor([[ 0.1444, -0.2479]])),
             ('bias', tensor([0.3885]))])

In [139]:
m3.load_state_dict(avg_state_dict)

<All keys matched successfully>

In [140]:
m3.state_dict()

OrderedDict([('weight', tensor([[ 0.1444, -0.2479]])),
             ('bias', tensor([0.3885]))])

### BERT部分でmean取っていいやつとそうでないやつがありそう
* position_idsは明らかにintが入ってる

### folds-meanを結果meanじゃなくてパラメータmeanで取れば安くなる

In [141]:
from glob import glob

In [154]:
model_paths = glob("/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/*.pth")

In [155]:
model_paths.sort()

In [156]:
model_paths

['/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold-mean.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold0.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold1.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold2.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold3.pth',
 '/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold4.pth']

In [143]:
from tqdm import tqdm

for i, path in enumerate(model_paths):
    if i == 0:
        avg_state_dict = torch.load(path)["model_state_dict"]
        print(f"... load {path} to set average-root ... ... ")
    else:
        print(f"... ... key add : {path}")
        for key in tqdm(avg_state_dict.keys(), total=len(avg_state_dict.keys())):
            if key.split(".")[-1] in ["weight", "bias"]:
                state_dict = torch.load(path)["model_state_dict"]
                avg_state_dict[key] += state_dict[key]


... ... load /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold3.pth to set average-root ... ... 
... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold1.pth


100%|██████████| 394/394 [02:08<00:00,  3.07it/s]


... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold2.pth


100%|██████████| 394/394 [02:09<00:00,  3.04it/s]


... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold4.pth


100%|██████████| 394/394 [02:09<00:00,  3.04it/s]


... key add : /mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold0.pth


100%|██████████| 394/394 [02:09<00:00,  3.05it/s]


In [144]:
def state_dict_divider(state_dict, div=2):
    """state_dictの値を割り算する関数"""
    for key in state_dict.keys():
        if key.split(".")[-1] in ["weight", "bias"]:
            state_dict[key] /= float(div)

    return state_dict

avg_state_dict = state_dict_divider(avg_state_dict, div=len(model_paths))

In [146]:
torch.save(
    {
        "model_state_dict": avg_state_dict,
    },
    f"/mnt/sdb/NISHIKA_DATA/hate-speech-detection/output/roberta-large_scheduler-warmup/model-fold-mean.pth",
)