In [1]:
### ライブラリの読み込み

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
import os
import cv2
import numpy as np
import datetime
import numpy as np
import random

import json
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as vF
import torch.nn.functional as F
from torchvision import models, transforms
from einops.layers.torch import Rearrange
import sklearn
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix
import warnings
import time



#設定

data_l_list = [1,2,3,4,5]#5分割交差検証
model_type_list = ["proposal_A", "proposal_B"]#手法名を格納
optim_name = "SGD"#最適化手法
learning_rate = 0.01#学習率
MHA_head_num = 8#mhaのヘッド数
t_layer_num = 4#transformerの層数
loss_func = "bce"#損失関数
epochs = 300#学習回数
early_stopping = True#アーリーストッピングをの実施可否
mimimum_updates = 0.#f1の評価値の更新幅が mimimum_updates 以上の場合、モデルの重み保存、early_stopping_countが0になる
early_stopping_limit = 30#early_stopping_limit 回f1の評価値の更新が見られない場合、学習打ち切り
save_model = True#モデルの重み保存可否
BATCH_SIZE = 8#バッチサイズ

#ログや学習済みの重みを保存するフォルダを作成

log_folder_path = "./log"

if os.path.isdir(log_folder_path) == False:
    os.mkdir(log_folder_path)

# メイン
for data_k in data_l_list:

    for model_type in model_type_list:
        
        # 再現性のためのシード値設定

        os.environ['PYTHONHASHSEED'] = '0'
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        np.random.seed(0)
        random.seed(0)
        torch.use_deterministic_algorithms = True
        torch.backends.cudnn.deterministic = True

        if data_k == 1:
            pos_weight = torch.tensor([1978/1095 , 2060/1013 , 2194/879 , 2366/707 , 2444/629]).to("cuda")
        if data_k == 2:
            pos_weight = torch.tensor([1982/1088 , 2016/1054 , 2193/877 , 2372/698 , 2431/639]).to("cuda")
        if data_k == 3:
            pos_weight = torch.tensor([2046/1120 , 2059/1107 , 2270/896 , 2469/697 , 2478/688]).to("cuda")
        if data_k == 4:
            pos_weight = torch.tensor([2055/1095 , 2059/1091 , 2276/874 , 2455/695 , 2488/662]).to("cuda")
        if data_k == 5:
            pos_weight = torch.tensor([2025/1107 , 2081/1051 , 2266/866 , 2413/719 , 2483/649]).to("cuda")
        

        day = datetime.datetime.now()
        day_info = f'{day.day}_{day.hour}_{day.minute}'


        model_dir = f'log/RFEN_{day_info}_{model_type}'


        model_weight_dir = f'{model_dir}/model_weight'

        train_info_dir=f'{model_dir}/train_log'
        train_log_dir = f'{train_info_dir}/json_log'
        valid_info_dir=f'{model_dir}/valid_log'
        valid_log_dir = f'{valid_info_dir}/json_log'
        test_info_dir=f'{model_dir}/test_log'
        test_log_dir = f'{test_info_dir}/json_log'

        if os.path.isdir(model_dir) == False:
            os.mkdir(model_dir)
            os.mkdir(model_weight_dir)
            os.mkdir(train_info_dir)
            os.mkdir(train_log_dir)
            os.mkdir(valid_info_dir)
            os.mkdir(valid_log_dir)
            os.mkdir(test_info_dir)
            os.mkdir(test_log_dir)
            
            

        ### データセット読み込み


        class_num = 5

        train_path = "./dataset/split_data{}/train_data.json".format(data_k)

        valid_path = "./dataset/split_data{}/valid_data.json".format(data_k)

        test_path = "./dataset/split_data{}/test_data.json".format(data_k)


        train_data_json = open(train_path, 'r')
        train_data = json.load(train_data_json)

        valid_data_json = open(valid_path, 'r')
        valid_data = json.load(valid_data_json)

        test_data_json = open(test_path, 'r')
        test_data = json.load(test_data_json)

        print("すべてのデータの数は", len(train_data)+len(valid_data)+len(test_data))
        print("訓練データの数は", len(train_data))
        print("検証データの数は", len(valid_data))
        print("テストデータの数は", len(test_data)) 



        ###オリジナルデータローダーを定義

        class MyDataset(Dataset):
            def __init__(self, data):
                super().__init__()
                self.data = data
                self.len = len(data)

            def __len__(self):
                return self.len

            def __getitem__(self, index):
                img_id_list = list(self.data.keys())

                subject_data = self.data[img_id_list[index]]#インデックスで指定されたimage_idのバリューを読み込む

                position = []
                class_label = []

                ############画像の読み込み############
                img_path_list = ["./dataset/drama_image/{}.jpg".format(int(id)) for id in img_id_list]
                subject_img_path = img_path_list[index]
                image_load = cv2.imread(subject_img_path)#cv2では(高さ、幅、チャネル)で読み込まれる1　(740, 1000, 3)

                image_load = cv2.resize(image_load, dsize=(1000, 740))
                image = np.array(image_load).astype(np.float32).transpose(2, 0, 1)#image_size(3,740,1000)(チャネル数, 高さ, 幅)

                image = torch.tensor(image)/255.#0～1に正規化

                ####################################

                ############座標、画像ID、クラスラベルを読み込む############
                for one_box_data in subject_data:

                    position.append(one_box_data['xyxy'])
                    class_label.append(one_box_data['class'])

                img_id = (one_box_data['image_id'])
                input_info = [position, img_id, class_label]
                ########################################################

                ############リスク要因を読み込む############
                gt = []
                for one_box_data in subject_data:
                    one_risk_focter_data = eval(one_box_data['risk_focter'])

                    gt.append(one_risk_focter_data)
                ##########################################


                return image, None, input_info ,gt
    
        def collate_fn(batch):
        
        
        
            images, infos, gts = [], [], []
            for image, none, info, label in batch:
                image = image.unsqueeze(dim=0)
                images.append(image)
                infos.append(info)
                gts.append(label)
            images = torch.cat(images, dim=0)
            # labelsはTensorリストのまま
            return images, None, infos, gts

            
            
        #データローダーを定義


        train_dataset = MyDataset(train_data)
        valid_dataset = MyDataset(valid_data)
        test_dataset = MyDataset(test_data)

        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
        valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)




        #デー数を確認

        print("訓練データのバッチ数は",len(train_dataloader))
        print("検証データのバッチ数は",len(valid_dataloader))
        print("テストデータのバッチ数は",len(test_dataloader))




        class image_feature_create(nn.Module):
            def __init__(self, image_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.custom_resnet50 = nn.Sequential(
                                models.resnet50(pretrained=True).conv1,
                                models.resnet50(pretrained=True).bn1,
                                models.resnet50(pretrained=True).relu,
                                models.resnet50(pretrained=True).maxpool)
                
                for param in self.custom_resnet50.parameters():
                    param.requires_grad = False
                
                #self.pool = nn.MaxPool2d(2)
                #self.conv1 = nn.Conv2d(3, 32, kernel_size=(3,3), stride=(1, 1), padding=1)
                #self.conv2 = nn.Conv2d(32, 64, kernel_size=(3,3), stride=(1, 1), padding=1)
                
            def forward(self, img):
                
                #x1 = self.pool(self.conv1(img))
                #out = self.pool(self.conv2(x1))
                out = self.custom_resnet50(img)
                
                return out


        class class_embedding(nn.Module):
            def __init__(self, dim):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                self.class_linear = nn.Linear(5, dim)
                
            def forward(self, class_label):  

                class_one_hot = torch.zeros(len(class_label), 5)
                class_one_hot[range(len(class_label)), class_label] = 1
                class_one_hot = class_one_hot.to(self.device)
                class_enb = self.class_linear(class_one_hot)
                
                out = F.relu(class_enb)
                
                return out


        class position_embedding(nn.Module):
            def __init__(self, dim):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                self.position_linear = nn.Linear(4, dim)
                
            def forward(self, position):  
                #print(position,"a")
                position = position/torch.tensor([1000, 740, 1000, 740]).to(self.device)#位置情報を0から1に正規化
                #print(position,"b")
                pos_enb = self.position_linear(position)
                
                out = F.relu(pos_enb)
                
                return out

        class time_embedding(nn.Module):
            def __init__(self, dim):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                self.linear = nn.Linear(1, dim)
                
            def forward(self, time_feat):  
                #print(position,"a")
                #print(position,"b")
                time_enb = self.linear(time_feat)
                
                out = F.relu(time_enb)
                
                return out

        class class_obj_fusion(nn.Module):
            def __init__(self, dim):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                self.concat_linear = nn.Linear(4096 * 2, dim)
                
            def forward(self, class_feature, obj_feature):  
                
                object_feature = torch.flatten(obj_feature, start_dim = 1, end_dim = 3)
                
                concat_feat = torch.cat([class_feature, object_feature], dim=1)
                x = self.concat_linear(concat_feat)
                
                out = F.relu(x)
                
                return out




        class Baseline_A_object_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.pool = nn.MaxPool2d(2)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, object_feat):
                x1 =  self.conv1(object_feat)
                out = self.conv2(F.relu(self.pool(x1)))
                    
                return F.relu(out)


        class baseline_A_fusion(nn.Module):
            def __init__(self, dim):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                self.linear1 = nn.Linear(8192, 4096)
                self.linear2 = nn.Linear(4096, dim)

            def forward(self, one_img_object_feature, one_img_position_feature):
                #one_img_object_feature = torch.flatten(one_img_object_feature, start_dim = 1, end_dim = 3)
                
                one_img_obj_pos_cat_feature = torch.cat((one_img_object_feature, one_img_position_feature), dim = 1).to(self.device)

                x1 = self.linear1(one_img_obj_pos_cat_feature)
                out = self.linear2(F.relu(x1))
                
                return F.relu(out)


        class baseline_A_head(nn.Module):
            def __init__(self, class_num):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                self.predicter1 = nn.Linear(4096, 256)
                self.predicter2 = nn.Linear(256, class_num)
                
            def forward(self, x):

                x1 = self.predicter1(x)
                out = self.predicter2(F.relu(x1))
                
                return out
            




        class Baseline_B_object_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.pool = nn.MaxPool2d(2)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, object_feat):
                x1 =  self.conv1(object_feat)
                out = self.conv2(F.relu(self.pool(x1)))
                    
                return F.relu(out)


        class Baseline_B_image_feature_extractor(nn.Module):
            def __init__(self, image_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.pool = nn.MaxPool2d(2)
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding=1)
                self.conv2 = nn.Conv2d(32, 16, kernel_size=(3,3), stride=(1, 1), padding=1)
                self.conv3 = nn.Conv2d(16, 16, kernel_size=(3,3), stride=(1, 1), padding=1)
                
            def forward(self, img):
                
                x1 = self.conv1(self.pool(F.relu(img)))
                x2 = self.conv2(self.pool(F.relu(x1)))
                x3 = self.conv3(self.pool(F.relu(x2)))
                
                return F.relu(x3)


        class baseline_B_fusion(nn.Module):
            def __init__(self, dim):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                self.image_linear = nn.Linear(12544, 4096)
                self.linear1 = nn.Linear(8192, 4096)
                self.linear2 = nn.Linear(4096, dim)
                
            def forward(self, one_img_object_feature, one_img_feature, one_position_feature):
                
                #one_img_object_feature = torch.flatten(one_img_object_feature, start_dim = 1, end_dim = 3)
                #print(one_img_feature.shape)
                one_img_feature = torch.flatten(one_img_feature, start_dim = 1, end_dim = 3)
                #print(one_img_feature.shape)
                one_img_feature = self.image_linear(one_img_feature)
                
                one_img_duplication_feature = one_img_feature.repeat((len(one_img_object_feature), 1))
                
                one_img_object_feature = one_img_object_feature + one_position_feature
                
                #one_img_obj_pos_cat_feature = torch.cat((one_img_object_feature, one_img_duplication_feature), dim = 1).to(self.device)
                #座標情報をconcatするなら上2行をコメントアウトして、下の行を有効にする
                one_img_obj_pos_cat_feature = torch.cat((one_img_object_feature, one_img_duplication_feature), dim = 1).to(self.device)
                
                x1 = self.linear1(one_img_obj_pos_cat_feature)
                out = self.linear2(F.relu(x1))

                return out


        class baseline_B_head(nn.Module):
            def __init__(self, class_num):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                self.predicter1 = nn.Linear(2048, 256)
                self.predicter2 = nn.Linear(256, class_num)
                
            def forward(self, x):

                x = self.predicter1(x)
                out = self.predicter2(F.relu(x))
                
                return out





        class proposal_object_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.pool = nn.MaxPool2d(2)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, object_feat):
                x1 =  self.conv1(object_feat)
                out = self.conv2(F.relu(self.pool(x1)))
                    
                return F.relu(out)


        class Patching(nn.Module):
            
            def __init__(self, x_patch_size, y_patch_size):
                """ [input]
                    - patch_size (int) : パッチの縦の長さ（=横の長さ）
                """
                super().__init__()
                self.net = Rearrange("b c (h ph) (w pw) -> b (h w) c ph pw", ph = y_patch_size, pw = x_patch_size)
            
            def forward(self, x):
                """ [input]
                    - x (torch.Tensor) : 画像データ
                        - x.shape = torch.Size([batch_size, channels, image_height, image_width])
                """
                x = self.net(x)
                return x


        class proposal_patch_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, patch_feat):
                x1 =  self.conv1(patch_feat)
                out = self.conv2(F.relu(x1))
                    
                return out


        class proposal_past_patch_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, patch_feat):
                x1 =  self.conv1(patch_feat)
                out = self.conv2(F.relu(x1))
                    
                return out


        class proposal_past_past_patch_feature_extractor(nn.Module):
            def __init__(self, region_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding = 1)
                self.conv2 = nn.Conv2d(32, region_channel, kernel_size=(3,3), stride=(1, 1), padding = 1)
                
            def forward(self, patch_feat):
                x1 =  self.conv1(patch_feat)
                out = self.conv2(F.relu(x1))
                    
                return out


        class A_positional_encording(nn.Module):
            def __init__(self):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                
            def forward(self, object_feat, object_pos):  
                #one_img_object_feature = torch.flatten(object_feat, start_dim = 1, end_dim = 3)
                
                one_img_object_feature = object_feat + object_pos
                
                return one_img_object_feature


        class proposal_B_positional_encording(nn.Module):
            def __init__(self):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                
            def forward(self, object_feat, object_pos, patch_feat, patch_pos):  
                #one_img_object_feature = torch.flatten(object_feat, start_dim = 1, end_dim = 3)

                one_img_patch_feature = torch.flatten(patch_feat, start_dim = 1, end_dim = 3)
                
                one_img_object_feature = object_feat + object_pos
                one_img_patch_feature = one_img_patch_feature + patch_pos
                
                return one_img_object_feature, one_img_patch_feature
            

        class proposal_C_positional_encording(nn.Module):
            def __init__(self):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                
            def forward(self, object_feat, object_pos, patch_feat, patch_pos):  
                
                #one_img_patch_feature = torch.flatten(patch_feat, start_dim = 1, end_dim = 3)
                
                one_img_object_feature = object_feat + object_pos
                one_img_patch_feature = patch_feat + patch_pos
                
                return one_img_object_feature, one_img_patch_feature


        class proposal_C_image_feature_extractor(nn.Module):
            def __init__(self, image_channel):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.pool = nn.MaxPool2d(2)
                self.conv1 = nn.Conv2d(64, 32, kernel_size=(3,3), stride=(1, 1), padding=1)
                self.conv2 = nn.Conv2d(32, 16, kernel_size=(3,3), stride=(1, 1), padding=1)
                self.conv3 = nn.Conv2d(16, 16, kernel_size=(3,3), stride=(1, 1), padding=1)
                self.fc = nn.Linear(12544, 4096)
                
            def forward(self, img):
                
                x1 = self.conv1(self.pool(F.relu(img)))
                x2 = self.conv2(self.pool(F.relu(x1)))
                x3 = self.conv3(self.pool(F.relu(x2)))
        
                x4 = torch.flatten(x3, start_dim = 1, end_dim = 3)
                out = self.fc(x4)
                
                return F.relu(out)


        class Merge_past_current_images(nn.Module):
            def __init__(self):
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                super().__init__()
                
                #self.conv = nn.Conv2d(16, 16, kernel_size=(3,3), stride=(1, 1), padding = 1)
                #self.linear = nn.Linear(8192, 4096)
                #self.conv = nn.Conv2d(32, 16, kernel_size=(3,3), stride=(1, 1), padding = 1)
                

            def forward(self, now_patch_feat, past_patch_feat):
                
                #print(now_patch_feat.shape)
                #print(past_patch_feat.shape)
                now_patch_flat_feat = torch.flatten(now_patch_feat, start_dim = 1, end_dim = 3)
                past_patch_flat_feat = torch.flatten(past_patch_feat, start_dim = 1, end_dim = 3)
                #print(now_patch_flat_feat.shape)
                #print(past_patch_flat_feat.shape)
                sequence_feat = torch.cat([now_patch_flat_feat, past_patch_flat_feat], dim =1)
                #print(sequence_feat.shape)
                #marge_feat = self.linear(sequence_feat)
                #marge_feat = self.linear(sequence_feat)
                #print(marge_feat.shape)
                #now_patch_flat_feat = torch.flatten(marge_feat, start_dim = 1, end_dim = 3)
                return sequence_feat



        class TransformerEncoder(nn.Module):
            def __init__(self, dim, n_heads, mlp_dim, depth):
                """ [input]
                    - dim (int) : 各パッチのベクトルが変換されたベクトルの長さ（参考[1] (1)式 D）
                    - depth (int) : Transformer Encoder の層の深さ（参考[1] (2)式 L）
                    - n_heads (int) : Multi-Head Attention の head の数
                    - mlp_dim (int) : MLP の隠れ層のノード数
                """
                super().__init__()

                # Layers
                self.norm = nn.LayerNorm(dim)
                self.multi_head_attention = MultiHeadAttention(dim = dim, n_heads = n_heads)
                self.mlp = MLP(dim = dim, hidden_dim = mlp_dim)
                self.depth = depth#

            def forward(self, x, mode):
                """[input]
                    - x (torch.Tensor)
                        - x.shape = torch.Size([batch_size, n_patches + 1, dim])
                """

                all_layer_att_weight = []
                
                for _ in range(self.depth):
                    
                    if mode != "test":
                        att_out = self.multi_head_attention(self.norm(x), mode)
                    elif mode == "test":
                        att_out, attention_weight = self.multi_head_attention(self.norm(x), mode)
                        all_layer_att_weight.append(attention_weight.detach().cpu().numpy())
                        
                    x = att_out + x
                    #print(x.shape)
                    x = self.mlp(self.norm(x)) + x

                if mode == "test":
                    #print(all_layer_att_weight[-1].shape)
                    all_layer_att_weight = np.array(all_layer_att_weight[-1])
                    #all_layer_att_weight = np.mean(all_layer_att_weight, axis=0)
                
                if mode != "test":
                    return x
                elif mode == "test":
                    return x, all_layer_att_weight


        class MLP(nn.Module):
            def __init__(self, dim, hidden_dim):
                """ [input]
                    - dim (int) : パッチのベクトルが変換されたベクトルの長さ
                    - hidden_dim (int) : 隠れ層のノード数
                """
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(dim, hidden_dim),
                    nn.GELU(),
                    nn.Linear(hidden_dim, dim)
                )

            def forward(self, x):
                """[input]
                    - x (torch.Tensor)
                        - x.shape = torch.Size([batch_size, n_patches + 1, dim])
                """
                x = self.net(x)
                return x


        class MultiHeadAttention(nn.Module):
            def __init__(self, dim, n_heads):
                """ [input]
                    - dim (int) : パッチのベクトルが変換されたベクトルの長さ
                    - n_heads (int) : heads の数
                """
                super().__init__()
                self.n_heads = n_heads
                self.dim_heads = dim // n_heads

                self.W_q = nn.Linear(dim, dim)
                self.W_k = nn.Linear(dim, dim)
                self.W_v = nn.Linear(dim, dim)

                self.split_into_heads = Rearrange("b n (h d) -> b h n d", h = self.n_heads)

                self.softmax = nn.Softmax(dim = -1)

                self.cat = Rearrange("b h n d -> b n (h d)", h = self.n_heads)

            def forward(self, x, mode):
                """[input]
                    - x (torch.Tensor)
                        - x.shape = torch.Size([batch_size, n_patches + 1, dim])
                """ 
                q = self.W_q(x)
                k = self.W_k(x)
                v = self.W_v(x)
                
                q = self.split_into_heads(q)#ヘッド数分、分割
                k = self.split_into_heads(k)
                v = self.split_into_heads(v)
                #print(q.shape)
                #print(k.transpose(-1, -2).shape,"k_t")
                # Logit[i] = Q[i] * tK[i] / sqrt(D) (i = 1, ... , n_heads)
                # AttentionWeight[i] = Softmax(Logit[i]) (i = 1, ... , n_heads)
                logit = torch.matmul(q, k.transpose(-1, -2)) * (self.dim_heads ** -0.5)
                #print(logit.shape,"logit")
                attention_weight = self.softmax(logit)
                #print(v.shape,"v")
                # Head[i] = AttentionWeight[i] * V[i] (i = 1, ... , n_heads)
                # Output = concat[Head[1], ... , Head[n_heads]]
                if mode == "test":
                    attention_weight_head_mean = torch.mean(attention_weight, dim = 1)
                
                output = torch.matmul(attention_weight, v)
                output = self.cat(output)
                #print(output.shape)
                if mode != "test":
                    return output
                elif mode == "test":
                    return output, attention_weight_head_mean


        class MLPHead(nn.Module):
            def __init__(self, dim, out_dim):
                super().__init__()
                self.net = nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Linear(dim, 512),
                    nn.ReLU(),
                    nn.Linear(512, out_dim)
                )
            
            def forward(self, x):
                x = self.net(x)
                return x

            
        class Risk_Facter_Estimation_Network(nn.Module):
            def __init__(self):
                
                super().__init__()
                
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                
                self.object_size = (32, 32)
                self.region_channel = 64
                self.image_size = (896, 896)
                self.class_num = class_num
                
                if model_type == "baseline_A":
                    self.class_dim = 4096
                    self.position_dim = 4096
                    self.class_obj_fusion_dim = 4096
                    self.fusion_dim = 2048
                    self.output_height = 32
                    self.output_width = 32
                    self.object_output_size = (self.output_height, self.output_width)
                    self.cnn_cre = image_feature_create(image_channel = 64)
                    self.object_extractor = Baseline_A_object_feature_extractor(region_channel = 16)
                    self.class_encording = class_embedding(dim = self.class_dim)
                    self.position_encording = position_embedding(dim = self.position_dim)
                    self.class_obj_fusion = class_obj_fusion(dim = self.class_obj_fusion_dim)
                    self.positional_encording = A_positional_encording()
                    self.head = baseline_A_head(self.class_num)
                    
                if model_type == "baseline_B":
                    self.class_dim = 4096
                    self.position_dim = 4096
                    self.class_obj_fusion_dim = 4096
                    self.fusion_dim = 2048
                    self.output_height = 32
                    self.output_width = 32
                    self.output_size = (self.output_height, self.output_width)
                    self.cnn_cre = image_feature_create(image_channel = 64)
                    self.object_extractor = Baseline_B_object_feature_extractor(region_channel = 16)
                    self.image_extractor = Baseline_B_image_feature_extractor(image_channel = 8)
                    self.class_encording = class_embedding(dim = self.class_dim)
                    self.position_encording = position_embedding(dim = self.position_dim)
                    self.class_obj_fusion = class_obj_fusion(dim = self.class_obj_fusion_dim)
                    self.fusion = baseline_B_fusion(dim = self.fusion_dim)
                    self.head = baseline_B_head(self.class_num)

                if model_type == "proposal_A":
                    self.class_dim = 4096
                    self.position_dim = 4096
                    self.fusion_dim = 4096
                    self.img_patch_size = (16,16)
                    self.output_height = 32
                    self.output_width = 32
                    self.output_size = (self.output_height, self.output_width)
                    self.cnn_cre = image_feature_create(image_channel = 64)
                    self.object_extractor = proposal_object_feature_extractor(region_channel = 16)
                    self.patch_extractor = proposal_patch_feature_extractor(region_channel = 16)
                    self.class_encording = class_embedding(dim = self.class_dim)
                    self.position_encording = position_embedding(dim = self.position_dim)
                    self.class_obj_fusion = class_obj_fusion(dim = self.fusion_dim)
                    self.patching = Patching(self.img_patch_size[0], self.img_patch_size[1])
                    self.positional_encording = A_positional_encording()
                    self.transformer_encoder = TransformerEncoder(dim = 4096, n_heads = MHA_head_num, mlp_dim = 4096, depth = t_layer_num)
                    self.mlp_head = MLPHead(dim = 4096, out_dim = self.class_num)

                if model_type == "proposal_B":
                    self.class_dim = 4096
                    self.position_dim = 4096
                    self.fusion_dim = 4096
                    self.img_patch_size = (16,16)
                    self.output_height = 32
                    self.output_width = 32
                    self.output_size = (self.output_height, self.output_width)
                    self.cnn_cre = image_feature_create(image_channel = 64)
                    self.object_extractor = proposal_object_feature_extractor(region_channel = 16)
                    self.patch_extractor = proposal_patch_feature_extractor(region_channel = 16)
                    self.class_encording = class_embedding(dim = self.class_dim)
                    self.position_encording = position_embedding(dim = self.position_dim)
                    self.class_obj_fusion = class_obj_fusion(dim = self.fusion_dim)
                    self.patching = Patching(self.img_patch_size[0], self.img_patch_size[1])
                    self.positional_encording = proposal_B_positional_encording()
                    self.transformer_encoder = TransformerEncoder(dim = 4096, n_heads = MHA_head_num, mlp_dim = 4096, depth = t_layer_num)
                    self.mlp_head = MLPHead(dim = 4096, out_dim = self.class_num)
                
                if model_type == "proposal_C":
                    self.class_dim = 4096
                    self.position_dim = 4096
                    self.fusion_dim = 4096
                    self.img_patch_size = (16,16)
                    self.output_height = 32
                    self.output_width = 32
                    self.output_size = (self.output_height, self.output_width)
                    self.cnn_cre = image_feature_create(image_channel = 64)
                    self.object_extractor = proposal_object_feature_extractor(region_channel = 16)
                    self.image_extractor = proposal_C_image_feature_extractor(image_channel = 8)
                    self.class_encording = class_embedding(dim = self.class_dim)
                    self.position_encording = position_embedding(dim = self.position_dim)
                    self.class_obj_fusion = class_obj_fusion(dim = self.fusion_dim)

                    self.positional_encording = A_positional_encording()
                    self.transformer_encoder = TransformerEncoder(dim = 4096, n_heads = MHA_head_num, mlp_dim = 4096, depth = t_layer_num)
                    self.mlp_head = MLPHead(dim = 4096, out_dim = self.class_num)

            
            def roi_align(self, feature_map, boxes, output_size):
                """
                ROI Alignの実装
                :param feature_map: 入力の特徴マップ (Tensor: N x C x H x W)
                :param boxes: ROIの座標 (Tensor: num_boxes x 5 [batch_index, x1, y1, x2, y2])
                :param output_size: 出力のサイズ (tuple: (output_height, output_width))
                :return: ROI Alignされた特徴マップ (Tensor: num_boxes x C x output_height x output_width)
                """
                # ROI Alignのパラメータ
                spatial_scale = 1.0
                sampling_ratio = -1
                aligned = True

                # ROI Alignの実行
                return torchvision.ops.roi_align(feature_map, boxes, output_size,
                                spatial_scale=spatial_scale,
                                sampling_ratio=sampling_ratio,
                                aligned=aligned)
            
            def forward(self, img, past_img, info, mode):
                
                if model_type == "baseline_A":
                    out = []
                    batch_resized_image = vF.resize(img, size=self.image_size).to(self.device)# torch.Size([batch数, 3, 896, 896])
                    cnn_features = self.cnn_cre(batch_resized_image)# torch.Size([batch数, 64, 224, 224]) 
                    box_pos = [one_image_data[0] for one_image_data in info]# print(len(box_pos)) -> (batchsize)
                    
                    for one_image_id in range(len(cnn_features)):
                        one_image_cnn_feature = cnn_features[one_image_id]# torch.Size([64, 224, 224]) 
                        one_image_cnn_feature = torch.unsqueeze(one_image_cnn_feature, dim = 0)# torch.Size([1, 64, 224, 224]) 
                        one_image_box_pos = torch.tensor(box_pos[one_image_id])# 一つの画像のすべての物体の位置情報
                        image_rate = torch.tensor([224./1000., 224./740., 224./1000., 224./740.])
                        
                        # 224×224にリサイズしているので、物体の位置座標を224×224のレートで変換する
                        one_image_box_after_conversion_pos = [torch.tensor([[a * b for a, b in zip(sublist, image_rate)] for sublist in one_image_box_pos]).to(self.device)]
                        
                        one_img_roi_aligned_object_feature_map = self.roi_align(one_image_cnn_feature, one_image_box_after_conversion_pos, self.object_output_size)# torch.Size([検出数, 64(チャネル), 32(高さ), 32(幅)])

                        one_img_object_featueres = self.object_extractor(one_img_roi_aligned_object_feature_map)#[16(チャネル), 16(高さ), 16(幅)]
                        
                        # 物体の位置情報を処理
                        one_img_obj_position = torch.tensor(info[one_image_id][0]).to(self.device)#(検出数, 4)
                        one_img_position_features = self.position_encording(one_img_obj_position)#(検出数, 4096)

                        # 物体のクラス情報を処理
                        one_img_obj_class = torch.tensor(info[one_image_id][2]).to(self.device)#(検出数, 1)
                        one_img_class_features = self.class_encording(one_img_obj_class)#(検出数, 4096)
                        
                        class_obj_fusion_features = self.class_obj_fusion(one_img_class_features, one_img_object_featueres)#クラス情報と物体情報を結合し特徴圧縮 (検出数, 4096)
                        
                        objest_features = self.positional_encording(class_obj_fusion_features, one_img_position_features)
                        
                        one_image_out = self.head(objest_features)

                        out.append(one_image_out)
                        
                elif model_type == "baseline_B":
                    out = []
                    batch_resized_image = vF.resize(img, self.image_size).to(self.device)# torch.Size([batch数, 3, 896, 896]
                    cnn_features = self.cnn_cre(batch_resized_image)# torch.Size([batch数, 64, 224, 224]) 
                    box_pos = [one_image_data[0] for one_image_data in info]# print(len(box_pos)) -> 4(batchsize)
                    
                    for one_image_id in range(len(cnn_features)):
                        one_image_cnn_feature = cnn_features[one_image_id]# torch.Size([64, 224, 224]) 
                        one_image_cnn_feature = torch.unsqueeze(one_image_cnn_feature, dim = 0)# torch.Size([1, 64, 224, 224]) 
                        one_image_box_pos = torch.tensor(box_pos[one_image_id])# 一つの画像のすべての物体の位置情報
                        image_rate = torch.tensor([224./1000., 224./740., 224./1000., 224./740.])
                        
                        # 224×224にリサイズしているので、物体の位置座標を224×224のレートで変換する
                        one_image_box_after_conversion_pos = [torch.tensor([[a * b for a, b in zip(sublist, image_rate)] for sublist in one_image_box_pos]).to(self.device)]
                        
                        one_img_roi_aligned_object_feature_map = self.roi_align(one_image_cnn_feature, one_image_box_after_conversion_pos, self.output_size)# torch.Size([検出数, 64(チャネル), 32(高さ), 32(幅)])

                        one_img_object_featueres = self.object_extractor(one_img_roi_aligned_object_feature_map)#[16(チャネル), 16(高さ), 16(幅)]
                        
                        # 画像全体の特徴を処理
                        one_img_featueres = self.image_extractor(one_image_cnn_feature)# [4096] 後々、検出数分複製
                        
                        # 物体の位置情報を処理
                        one_img_obj_position = torch.tensor(info[one_image_id][0]).to(self.device)#(検出数, 4)
                        one_img_position_features = self.position_encording(one_img_obj_position)#(検出数, 4096)

                        # 物体のクラス情報を処理
                        one_img_obj_class = torch.tensor(info[one_image_id][2]).to(self.device)#(検出数, 1)
                        one_img_class_features = self.class_encording(one_img_obj_class)#(検出数, 4096)
                        
                        class_obj_fusion_features = self.class_obj_fusion(one_img_class_features, one_img_object_featueres)#クラス情報と物体情報を結合し特徴圧縮 (検出数, 4096)
                        
                        fusion_features = self.fusion(class_obj_fusion_features, one_img_featueres, one_img_position_features)#位置情報と画像全体情報とクラス+物体情報を結合し特徴圧縮 (検出数, 2048)

                        one_image_out = self.head(fusion_features)

                        out.append(one_image_out)

                elif model_type == "proposal_A":
                    out = []
                    batch_att = []
                    #　448, 448にサイズ変更
                    batch_resized_image = vF.resize(img, size=self.image_size).to(self.device)
                    
                    # 画像全体情報を抽出　print(cnn_features.shape) -> torch.Size([4(batch_size), 16, 224, 224])
                    cnn_features = self.cnn_cre(batch_resized_image)
                    
                    #位置情報を読み込む
                    box_pos = [one_image_data[0] for one_image_data in info]# print(len(box_pos)) -> 4(batchsize)
                    
                    #画像1枚づつ処理する
                    for one_image_id in range(len(cnn_features)):
                        
                        one_image_cnn_feature = cnn_features[one_image_id]
                        one_image_cnn_feature = torch.unsqueeze(one_image_cnn_feature, dim = 0)
                        
                        one_image_box_pos = torch.tensor(box_pos[one_image_id])
                        image_rate = torch.tensor([224./1000., 224./740., 224./1000., 224./740.])
                        one_image_box_after_conversion_pos = [torch.tensor([[a * b for a, b in zip(sublist, image_rate)] for sublist in one_image_box_pos]).to(self.device)]
                        
                        #roi_alignで、画像全体情報から物体領域を切り取る
                        one_img_roi_aligned_object_feature_map = self.roi_align(one_image_cnn_feature, one_image_box_after_conversion_pos, self.output_size)# torch.Size([検出数, チャネル, 高さ, 幅])
                        
                        #切り取った領域の特徴をCNNでさらに圧縮する
                        one_img_object_featueres = self.object_extractor(one_img_roi_aligned_object_feature_map)
                        
                        #物体の位置情報を埋め込み表現として獲得する
                        one_img_obj_position = torch.tensor(info[one_image_id][0]).to(self.device)#(検出数, 4)
                        one_img_position_features = self.position_encording(one_img_obj_position)#(検出数, 4096)

                        one_img_obj_class = torch.tensor(info[one_image_id][2]).to(self.device)#(検出数, 1)
                        one_img_class_features = self.class_encording(one_img_obj_class)#(検出数, 4096)
                        
                        class_obj_fusion_features = self.class_obj_fusion(one_img_class_features, one_img_object_featueres)
                        
                        #位置エンコーディング
                        objest_features = self.positional_encording(class_obj_fusion_features, one_img_position_features)
                        
                        #検出数は画像によりバラバラなので、20の固定数になるまで0でパディングする
                        if objest_features.shape[0] < 20:#パディング
                            while True:
                                zero_pad = torch.zeros(objest_features.shape[1]).to(self.device).unsqueeze(dim = 0)
                                
                                objest_features = torch.cat((objest_features, zero_pad), dim = 0)
                                if objest_features.shape[0] == 20:
                                    break
                        elif objest_features.shape[0] > 20:
                                    objest_features[:20,:]

                        #入力トークンを作成
                        all_token = objest_features.unsqueeze(dim = 0)
                        
                        #作成された物体の情報のトークンをTransformer_Encoderに入力する
                        if mode != "test":
                            #print(all_token.shape)
                            x = self.transformer_encoder(all_token, mode)
                        elif mode == "test":
                            x, att = self.transformer_encoder(all_token, mode)
                            batch_att.append(att)
                        
                        #出力から物体のトークンのみに絞る
                        x = x[:,:20,:]
                        
                        #ヘッドで処理をする
                        one_img_out = self.mlp_head(x).squeeze(dim = 0)

                        out.append(one_img_out)
                        
                elif model_type == "proposal_B":
                    out = []
                    batch_att = []
                    # img(torch.Size([batchsize, 3, 740, 1000]))
                    #　448, 448にサイズ変更
                    batch_resized_image = vF.resize(img, size=self.image_size).to(self.device)
                    #torch.manual_seed(0)
                    # 画像全体情報を抽出　print(cnn_features.shape) -> torch.Size([4(batch_size), 16, 224, 224])
                    cnn_features = self.cnn_cre(batch_resized_image)# print(cnn_features.shape) -> torch.Size([4(batch_size), 16, 224, 224])
                    
                    #位置情報を読み込む
                    box_pos = [one_image_data[0] for one_image_data in info]# print(len(box_pos)) -> 4(batchsize)
                    
                    #画像1枚づつ処理する
                    for one_image_id in range(len(cnn_features)):
                        
                        one_image_cnn_feature = cnn_features[one_image_id]
                        one_image_cnn_feature = torch.unsqueeze(one_image_cnn_feature, dim = 0)
                        
                        one_image_box_pos = torch.tensor(box_pos[one_image_id])
                        image_rate = torch.tensor([224./1000., 224./740., 224./1000., 224./740.])
                        one_image_box_after_conversion_pos = [torch.tensor([[a * b for a, b in zip(sublist, image_rate)] for sublist in one_image_box_pos]).to(self.device)]
                        
                        #roi_alignで、画像全体情報から物体領域を切り取る
                        one_img_roi_aligned_object_feature_map = self.roi_align(one_image_cnn_feature, one_image_box_after_conversion_pos, self.output_size)# torch.Size([検出数, チャネル, 高さ, 幅])

                        #切り取った領域の特徴をCNNでさらに圧縮する
                        one_img_object_featueres = self.object_extractor(one_img_roi_aligned_object_feature_map)

                        #物体の位置情報を埋め込み表現として獲得する
                        one_img_obj_position = torch.tensor(info[one_image_id][0]).to(self.device)#(検出数, 4)
                        one_img_position_features = self.position_encording(one_img_obj_position)#(検出数, 4096)
                        
                        one_img_obj_class = torch.tensor(info[one_image_id][2]).to(self.device)#(検出数, 1)
                        one_img_class_features = self.class_encording(one_img_obj_class)#(検出数, 4096)
                    
                        class_obj_fusion_features = self.class_obj_fusion(one_img_class_features, one_img_object_featueres)
                        
                        #画像全体特徴をパッチに分割する
                        one_img_patchs = torch.squeeze(self.patching(one_image_cnn_feature), dim= 0)#(patch_num, chanel, 縦, 幅)
                        
                        #切り取ったパッチの特徴をCNNでさらに圧縮する
                        one_img_patchs_features = self.patch_extractor(one_img_patchs)
                        #print(one_img_patchs_features.shape)
                        #パッチの位置情報を取得し、パッチ位置情報を埋め込み表現として獲得する
                        one_img_patch_pos = []
                        for y in range(224//self.img_patch_size[1]):
                            for x in range(224//self.img_patch_size[0]):
                                one_patch_pos = torch.tensor([self.img_patch_size[0] * x ,self.img_patch_size[1] * y, 
                                                            self.img_patch_size[0] * x + self.img_patch_size[0] ,self.img_patch_size[1] * y + self.img_patch_size[1]])
                                one_img_patch_pos.append(one_patch_pos)
                        one_img_patch_pos = torch.stack(one_img_patch_pos).to(self.device)
                        one_img_patchs_position_features = self.position_encording(one_img_patch_pos)#(パッチ数, 4096)
                        
                        #位置エンコーディング
                        objest_features, patch_features = self.positional_encording(class_obj_fusion_features, one_img_position_features, one_img_patchs_features, one_img_patchs_position_features)
                        
                        #検出数は画像によりバラバラなので、20の固定数になるまで0でパディングする
                        if objest_features.shape[0] < 20:#パディング
                            while True:
                                zero_pad = torch.zeros(objest_features.shape[1]).to(self.device).unsqueeze(dim = 0)
                                
                                objest_features = torch.cat((objest_features, zero_pad), dim = 0)
                                if objest_features.shape[0] == 20:
                                    break
                        elif objest_features.shape[0] > 20:
                                    objest_features[:20,:]
                    
                        #入力トークンを作成
                        all_token = torch.cat([objest_features, patch_features], dim = 0).unsqueeze(dim = 0)
                        
                        #作成された物体の情報、パッチの情報のトークンをTransformer_Encoderに入力する
                        #print(all_token.shape)
                        if mode != "test":
                            #print(all_token.shape)
                            x = self.transformer_encoder(all_token, mode)
                        elif mode == "test":
                            x, att = self.transformer_encoder(all_token, mode)
                            batch_att.append(att)
                        
                        #出力から物体のトークンのみに絞る
                        x = x[:,:20,:]
                        
                        #ヘッドで処理をする
                        one_img_out = self.mlp_head(x).squeeze(dim = 0)
                        
                        out.append(one_img_out)
                
                elif model_type == "proposal_C":
                    out = []
                    batch_att = []
                    # img(torch.Size([batchsize, 3, 740, 1000]))
                    #　448, 448にサイズ変更
                    batch_resized_image = vF.resize(img, size=self.image_size).to(self.device)
                    #torch.manual_seed(0)
                    # 画像全体情報を抽出　print(cnn_features.shape) -> torch.Size([4(batch_size), 16, 224, 224])
                    cnn_features = self.cnn_cre(batch_resized_image)# print(cnn_features.shape) -> torch.Size([4(batch_size), 16, 224, 224])
                    
                    #位置情報を読み込む
                    box_pos = [one_image_data[0] for one_image_data in info]# print(len(box_pos)) -> 4(batchsize)
                    
                    #画像1枚づつ処理する
                    for one_image_id in range(len(cnn_features)):
                        
                        one_image_cnn_feature = cnn_features[one_image_id]
                        one_image_cnn_feature = torch.unsqueeze(one_image_cnn_feature, dim = 0)
                        
                        one_image_box_pos = torch.tensor(box_pos[one_image_id])
                        image_rate = torch.tensor([224./1000., 224./740., 224./1000., 224./740.])
                        one_image_box_after_conversion_pos = [torch.tensor([[a * b for a, b in zip(sublist, image_rate)] for sublist in one_image_box_pos]).to(self.device)]
                        
                        #roi_alignで、画像全体情報から物体領域を切り取る
                        one_img_roi_aligned_object_feature_map = self.roi_align(one_image_cnn_feature, one_image_box_after_conversion_pos, self.output_size)# torch.Size([検出数, チャネル, 高さ, 幅])
        
                        #切り取った領域の特徴をCNNでさらに圧縮する
                        one_img_object_featueres = self.object_extractor(one_img_roi_aligned_object_feature_map)
        
                        #物体の位置情報を埋め込み表現として獲得する
                        one_img_obj_position = torch.tensor(info[one_image_id][0]).to(self.device)#(検出数, 4)
                        one_img_position_features = self.position_encording(one_img_obj_position)#(検出数, 4096)
                        
                        one_img_obj_class = torch.tensor(info[one_image_id][2]).to(self.device)#(検出数, 1)
                        one_img_class_features = self.class_encording(one_img_obj_class)#(検出数, 4096)
                       
                        class_obj_fusion_features = self.class_obj_fusion(one_img_class_features, one_img_object_featueres)
                        
        
                        one_img_featueres = self.image_extractor(one_image_cnn_feature)# [1,4096] 
                        
        
                        #print(one_img_patchs_features.shape)
                        #パッチの位置情報を取得し、パッチ位置情報を埋め込み表現として獲得する
                        one_img_patch_pos = []
                        for y in range(224//self.img_patch_size[1]):
                            for x in range(224//self.img_patch_size[0]):
                                one_patch_pos = torch.tensor([self.img_patch_size[0] * x ,self.img_patch_size[1] * y, 
                                                              self.img_patch_size[0] * x + self.img_patch_size[0] ,self.img_patch_size[1] * y + self.img_patch_size[1]])
                                one_img_patch_pos.append(one_patch_pos)
                        one_img_patch_pos = torch.stack(one_img_patch_pos).to(self.device)
                        one_img_patchs_position_features = self.position_encording(one_img_patch_pos)#(パッチ数, 4096)
                        
                        #位置エンコーディング
                        objest_features = self.positional_encording(class_obj_fusion_features, one_img_position_features)
                        
                        #検出数は画像によりバラバラなので、20の固定数になるまで0でパディングする
                        if objest_features.shape[0] < 20:#パディング
                            while True:
                                zero_pad = torch.zeros(objest_features.shape[1]).to(self.device).unsqueeze(dim = 0)
                                
                                objest_features = torch.cat((objest_features, zero_pad), dim = 0)
                                if objest_features.shape[0] == 20:
                                    break
                        elif objest_features.shape[0] > 20:
                                    objest_features[:20,:]

                        #入力トークンを作成
                        all_token = torch.cat([objest_features, one_img_featueres], dim = 0).unsqueeze(dim = 0)

                        #作成された物体の情報、パッチの情報のトークンをTransformer_Encoderに入力する
                        #print(all_token.shape)
                        if mode != "test":
                            #print(all_token.shape)
                            x = self.transformer_encoder(all_token, mode)
                        elif mode == "test":
                            x, att = self.transformer_encoder(all_token, mode)
                            batch_att.append(att)

                        #出力から物体のトークンのみに絞る
                        x = x[:,:20,:]

                        #ヘッドで処理をする
                        one_img_out = self.mlp_head(x).squeeze(dim = 0)

                        out.append(one_img_out)


                if mode != "test" or model_type == "baseline_A" or model_type == "baseline_B":
                    return out
                elif mode == "test":
                    print("out")
                    return out, np.array(batch_att)
                
                

        ###学習周りの定義


        warnings.simplefilter('ignore')

        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = Risk_Facter_Estimation_Network().to(device)

        model.train()

        if optim_name == "SGD":
            optimizer = optim.SGD(model.parameters(),lr=learning_rate)
        elif optim_name == "Adam":
            optimizer = optim.Adam(model.parameters(),lr=learning_rate)
        elif optim_name == "RAdam":
            optimizer = optim.RAdam(model.parameters(),lr=learning_rate)
        else:
            print("指定の最適化関数を定義してください(SGD, Adam ,RAdam)")

        def loss_fn(pre, gt):

            if loss_func == "bce":
                loss_fn_b = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                #loss_fn_b = nn.BCEWithLogitsLoss()

            batch_loss_list = []

            for one_img_index in range(len(gt)):#gtが[[0,1,0,0,0,0],[],[1,0,0,0,0,0]]の場合  バッチサイズで繰り返し

                risk_off_indexes = [index for index, sub_lst in enumerate(gt[one_img_index]) if not sub_lst]#リスク要因が[]になっているインデックスを取得 [1]

                risk_on_gt_list = [sub_lst for sub_lst in gt[one_img_index] if sub_lst]#gtから空リストを除く [[0,1,0,0,0,0],[1,0,0,0,0,0]]

                if len(risk_on_gt_list) != 0:#1つの画像内に危険物体が1つでも存在する場合
                    one_gt = torch.tensor(risk_on_gt_list, dtype = torch.float32).to(device)
                    
                    one_pre = pre[one_img_index]
                    
                    if model_type == "proposal_A" or model_type == "proposal_B" or model_type == "proposal_C":
                        target_obj_index = len(gt[one_img_index])
                        one_pre = one_pre[0:target_obj_index]
                        
                    if len(risk_off_indexes) != 0:
                        
                        risk_on_pre_list = [row for i, row in enumerate(one_pre) if i not in risk_off_indexes]#gtのリスク要因が[]の場所でpreを消す
                        one_pre = torch.stack(risk_on_pre_list)

                    one_loss = loss_fn_b(one_pre,one_gt)

                else:#1つの画像内に危険物体が1つも存在しない場合
                        one_loss = torch.tensor(0.).to(device)

                batch_loss_list.append(one_loss)

            batch_loss_tensor = torch.stack(batch_loss_list)
            loss = torch.mean(batch_loss_tensor)
            #print(loss)
            
            return loss

        def batch_cm_calculation(pre, gt):
                
            batch_cm = []
            
            for one_img_index in range(len(gt)):#画像1枚づつ処理
                    
                    risk_off_indexes = [index for index, sub_lst in enumerate(gt[one_img_index]) if not sub_lst]#空リストの要素番号を記録
                    risk_on_gt_list = [sub_lst for sub_lst in gt[one_img_index] if sub_lst]#空リストを除いた正解ラベル
                    
                    target_obj_index = len(gt[one_img_index])
                    
                    if len(risk_on_gt_list) != 0:#1つの画像内に危険物体が1つでも存在する場合　空の場合実行されない
                        
                        one_gt = np.array(risk_on_gt_list)
                        one_pre = pre[one_img_index]
                        one_pre = labeling(one_pre[0:target_obj_index])
                        
                        if len(risk_off_indexes) != 0:#空リストの部分をpreリストから除く
                            
                            risk_on_pre_list = [row for i, row in enumerate(one_pre) if i not in risk_off_indexes]
                            one_pre = torch.stack(risk_on_pre_list)
                        one_pre = one_pre.detach().cpu().numpy()
                        
                        one_cm = multilabel_confusion_matrix(one_gt, one_pre)
                        batch_cm.append(one_cm)
                        

            batch_cm = np.sum(np.array(batch_cm), axis = 0)

            return batch_cm

        def evaluation(cm):
            recall = []
            precision = []
            f1 = []
            accuracy = []
            
            for i in range(len(cm)):
                one_recall = cm[i][1][1] / (cm[i][1][1] + cm[i][1][0])
                one_precision = cm[i][1][1] / (cm[i][1][1] + cm[i][0][1])
                one_f1 = 2 * cm[i][1][1] / (2 * cm[i][1][1] + cm[i][0][1] + cm[i][1][0])
                one_accuracy = (cm[i][1][1] + cm[i][0][0]) /(cm[i][1][1] + cm[i][0][0] + cm[i][0][1] + cm[i][1][0])
                
                recall.append(one_recall)
                precision.append(one_precision)
                f1.append(one_f1)
                accuracy.append(one_accuracy)
            return recall, precision, f1, accuracy

        def labeling(pre):

                pre = F.sigmoid(pre)
                
                for i in range(len(pre)):
                    pre[i] = torch.where(pre[i] >= 0.5, torch.ones_like(pre[i]), torch.zeros_like(pre[i]))
                        
                return pre

        def train(dataloader, model):
            """
            IMAGE:size (torch.Size([batch_size, 3, 740, 1000]))
            INFO:size  (batch_size)
            INFO:content [one_img[[[box1[x1,y1,x2,y2], box2[x1,y1,x2,y2]], img_id], two_img[[[]]]......]
            GT:size (batch_size)
            GT:[one_img['[risk_factor]', '[risk_factor]',...], two_img['[risk_factor]', '[risk_factor]',...]]
            """
            
            model.train()
            
            one_epo_loss = []
            
            epoch_cm = []
            
            for batch, (IMAGE, PAST_IMAGE, INFO, GT) in enumerate(dataloader):
                #print(INFO)
                pred = model(IMAGE, PAST_IMAGE, INFO, mode = "train")# predのサイズは[バッチ数, 検出数, ラベル数]
                
                one_batch_loss = loss_fn(pred, GT)
                
                # Backpropagation
                optimizer.zero_grad()
                one_batch_loss.backward()
                optimizer.step()
                
                batch_cm = batch_cm_calculation(pred, GT)
                
                epoch_cm.append(batch_cm)
                
                one_epo_loss.append(float(one_batch_loss))
                
            epoch_loss = np.average(one_epo_loss)    
                
            epoch_cm = np.sum(np.array(epoch_cm), axis = 0)
            
            #all_evaluation
            integration_label_cm = np.sum(epoch_cm, axis = 0)
            integration_label_cm = np.expand_dims(integration_label_cm, 0)
            recall, precision, f1, accuracy = evaluation(integration_label_cm)
            
            print(epoch_loss, "train_loss")
            
            return epoch_loss,    recall, precision, f1, accuracy

        def valid(dataloader, model):
            """
            IMAGE:size (torch.Size([batch_size, 3, 740, 1000]))
            INFO:size  (batch_size)
            INFO:content [one_img[[[box1[x1,y1,x2,y2], box2[x1,y1,x2,y2]], img_id], two_img[[[]]]......]
            GT:size (batch_size)
            GT:[one_img['[risk_factor]', '[risk_factor]',...], two_img['[risk_factor]', '[risk_factor]',...]]
            """
            
            model.eval()
            
            one_epo_loss = []
            
            valid_cm = []
            
            for batch, (IMAGE, PAST_IMAGE, INFO, GT) in enumerate(dataloader):
                
                pred = model(IMAGE, PAST_IMAGE, INFO, mode = "valid")
                
                one_batch_loss = loss_fn(pred, GT)
                
                batch_cm = batch_cm_calculation(pred, GT)
                
                one_epo_loss.append(float(one_batch_loss))
                
                valid_cm.append(batch_cm)
                
            valid_loss = np.average(one_epo_loss) 
            
            valid_cm = np.sum(np.array(valid_cm), axis = 0)
            
            #all_evaluation
            integration_label_cm = np.sum(valid_cm, axis = 0)
            integration_label_cm = np.expand_dims(integration_label_cm, 0)
            
            micro_recall, micro_precision, micro_f1, micro_accuracy = evaluation(integration_label_cm)
            
            #each_label_evaluation
            each_label_recall, each_label_precision, each_label_f1, each_label_accuracy = evaluation(valid_cm)
            
            macro_recall, macro_precision, macro_f1, macro_accuracy = np.nanmean(each_label_recall), np.nanmean(each_label_precision), np.nanmean(each_label_f1), np.nanmean(each_label_accuracy)
            #print(each_label_precision)
            
            print("検証macro_f1は", macro_f1)
            
            return valid_loss,   macro_recall, macro_precision, macro_f1, macro_accuracy,    micro_recall, micro_precision, micro_f1, micro_accuracy,    each_label_recall, each_label_precision, each_label_f1, each_label_accuracy,   valid_cm

        def test(dataloader, model):
            """
            IMAGE:size (torch.Size([batch_size, 3, 740, 1000]))
            INFO:size  (batch_size)
            INFO:content [one_img[[[box1[x1,y1,x2,y2], box2[x1,y1,x2,y2]], img_id], two_img[[[]]]......]
            GT:size (batch_size)
            GT:[one_img['[risk_factor]', '[risk_factor]',...], two_img['[risk_factor]', '[risk_factor]',...]]
            """
            
            model.eval()
            
            test_cm = []
            
            for batch, (IMAGE, PAST_IMAGE, INFO, GT) in enumerate(dataloader):
                
                pred = model(IMAGE, PAST_IMAGE, INFO, mode = "epoch_test")
                
                batch_cm = batch_cm_calculation(pred, GT)
                
                test_cm.append(batch_cm)
            
            test_cm = np.sum(np.array(test_cm), axis = 0)
            
            #all_evaluation
            integration_label_cm = np.sum(test_cm, axis = 0)
            integration_label_cm = np.expand_dims(integration_label_cm, 0)
            
            micro_recall, micro_precision, micro_f1, micro_accuracy = evaluation(integration_label_cm)
            
            #each_label_evaluation
            each_label_recall, each_label_precision, each_label_f1, each_label_accuracy = evaluation(test_cm)
            
            macro_recall, macro_precision, macro_f1, macro_accuracy = np.nanmean(each_label_recall), np.nanmean(each_label_precision), np.nanmean(each_label_f1), np.nanmean(each_label_accuracy)
            #print(each_label_precision)
            
            print(micro_recall, "micro_recall")
            print(micro_precision, "micro_precision")
            print(micro_f1, "micro_f1")
            print(micro_accuracy, "micro_accuracy")
            
            print(macro_recall, "macro_recall")
            print(macro_precision, "macro_precision")
            print(macro_f1, "macro_f1")
            print(macro_accuracy, "macro_accuracy")
            
            print("ラベルごとのf1")
            print("ラベル1 ", float(each_label_f1[0]))
            print("ラベル2 ", float(each_label_f1[1]))
            print("ラベル3 ", float(each_label_f1[2]))
            print("ラベル4 ", float(each_label_f1[3]))
            print("ラベル5 ", float(each_label_f1[4]))
            return macro_recall, macro_precision, macro_f1, macro_accuracy,    micro_recall, micro_precision, micro_f1, micro_accuracy,    each_label_recall, each_label_precision, each_label_f1, each_label_accuracy,   test_cm





        ### メイン


        top_t = 0

        params = 0
        for p in model.parameters():
            if p.requires_grad:
                params += p.numel()

        log_dict = {"model_type": str(model_type), "MHA_head_num":str(MHA_head_num), "t_layer_num":str(t_layer_num),
                    "params_num": str(params), "data_k":str(data_k), "epochs": str(epochs), 
                    "early_stopping_limit": str(early_stopping_limit), "mimimum_updates":str(mimimum_updates),
                    "loss_func": str(loss_func), "optimizer": str(optim_name), "learning_rate": str(learning_rate)}

        with open(f"{model_dir}/log.json", 'w') as f:
            json.dump(log_dict, f)
        print("ログを保存しました")

        early_stopping_count = 0
        early_stopping_max_f1_buffer = 0.

        train_all_loss = []
        train_epoch_recall_list = []
        train_epoch_precision_list = []
        train_epoch_f1_list = []
        train_epoch_accuracy_list = []

        valid_all_loss = []

        micro_valid_epoch_recall_list = []
        micro_valid_epoch_precision_list = []
        micro_valid_epoch_f1_list = []
        micro_valid_epoch_accuracy_list = []

        macro_valid_epoch_recall_list = []
        macro_valid_epoch_precision_list = []
        macro_valid_epoch_f1_list = []
        macro_valid_epoch_accuracy_list = []

        valid_each_label_recall_list = []
        valid_each_label_precision_list = []
        valid_each_label_f1_list = []
        valid_each_label_accuracy_list = []

        valid_cm_list = []

        micro_test_epoch_recall_list = []
        micro_test_epoch_precision_list = []
        micro_test_epoch_f1_list = []
        micro_test_epoch_accuracy_list = []

        macro_test_epoch_recall_list = []
        macro_test_epoch_precision_list = []
        macro_test_epoch_f1_list = []
        macro_test_epoch_accuracy_list = []

        test_each_label_recall_list = []
        test_each_label_precision_list = []
        test_each_label_f1_list = []
        test_each_label_accuracy_list = []

        test_cm_list = []

        print(model_type, "で学習を実行します")

        for t in range(epochs):
            print(f"Epoch {t+1}\\{epochs}-------------------------------")
            print("学習中....")#--------------------------------------------------------------------------
            start_time = time.time()

            
            model.train()
            train_epoch_loss, train_epoch_recall, train_epoch_precision, train_epoch_f1, train_epoch_accuracy = train(train_dataloader, model)
            train_all_loss.append(train_epoch_loss)
            train_epoch_recall_list.append(train_epoch_recall)
            train_epoch_precision_list.append(train_epoch_precision)
            train_epoch_f1_list.append(train_epoch_f1)
            train_epoch_accuracy_list.append(train_epoch_accuracy)
            
            end_time = time.time()
            
            
            print("検証中....")#--------------------------------------------------------------------------
            valid_epoch_loss,    macro_valid_epoch_recall, macro_valid_epoch_precision, macro_valid_epoch_f1, macro_valid_epoch_accuracy,   micro_valid_epoch_recall, micro_valid_epoch_precision, micro_valid_epoch_f1, micro_valid_epoch_accuracy,    valid_epoch_each_label_recall, valid_epoch_each_label_precision, valid_epoch_each_label_f1, valid_epoch_each_label_accuracy, valid_cm = valid(valid_dataloader, model)
            
            valid_all_loss.append(valid_epoch_loss)
            micro_valid_epoch_recall_list.append(micro_valid_epoch_recall)
            micro_valid_epoch_precision_list.append(micro_valid_epoch_precision)
            micro_valid_epoch_f1_list.append(micro_valid_epoch_f1)
            micro_valid_epoch_accuracy_list.append(micro_valid_epoch_accuracy)
            
            macro_valid_epoch_recall_list.append(macro_valid_epoch_recall)
            macro_valid_epoch_precision_list.append(macro_valid_epoch_precision)
            macro_valid_epoch_f1_list.append(macro_valid_epoch_f1)
            macro_valid_epoch_accuracy_list.append(macro_valid_epoch_accuracy)
            
            valid_each_label_recall_list.append(valid_epoch_each_label_recall)
            valid_each_label_precision_list.append(valid_epoch_each_label_precision)
            valid_each_label_f1_list.append(valid_epoch_each_label_f1)
            valid_each_label_accuracy_list.append(valid_epoch_each_label_accuracy)
            
            if isinstance(valid_cm, list):
                pass
            else:
                valid_cm = valid_cm.tolist()
            valid_cm_list.append(valid_cm)
            
            
            print("テスト中....")#--------------------------------------------------------------------------
            macro_test_epoch_recall, macro_test_epoch_precision, macro_test_epoch_f1, macro_test_epoch_accuracy,   micro_test_epoch_recall, micro_test_epoch_precision, micro_test_epoch_f1, micro_test_epoch_accuracy,    test_epoch_each_label_recall, test_epoch_each_label_precision, test_epoch_each_label_f1, test_epoch_each_label_accuracy, test_cm = test(test_dataloader, model)

            micro_test_epoch_recall_list.append(micro_test_epoch_recall)
            micro_test_epoch_precision_list.append(micro_test_epoch_precision)
            micro_test_epoch_f1_list.append(micro_test_epoch_f1)
            micro_test_epoch_accuracy_list.append(micro_test_epoch_accuracy)
            
            macro_test_epoch_recall_list.append(macro_test_epoch_recall)
            macro_test_epoch_precision_list.append(macro_test_epoch_precision)
            macro_test_epoch_f1_list.append(macro_test_epoch_f1)
            macro_test_epoch_accuracy_list.append(macro_test_epoch_accuracy)
            
            test_each_label_recall_list.append(test_epoch_each_label_recall)
            test_each_label_precision_list.append(test_epoch_each_label_precision)
            test_each_label_f1_list.append(test_epoch_each_label_f1)
            test_each_label_accuracy_list.append(test_epoch_each_label_accuracy)
            
            if isinstance(test_cm, list):
                pass
            else:
                test_cm = test_cm.tolist()
            test_cm_list.append(test_cm)
            
            
            
            print("<epoch_time> ", int(end_time) - int(start_time), " seconds")
            
            if early_stopping == True:
                f1_update_width = early_stopping_max_f1_buffer - macro_valid_epoch_f1
                
                if f1_update_width <= mimimum_updates:
                    if save_model == True:
                        if t != 0:
                            os.remove( f"{model_weight_dir}/{model_type}_{top_t}.pt")
                        torch.save(model, f"{model_weight_dir}/{model_type}_{t+1}.pt")
                        top_t = t+1
                    early_stopping_count = 0
                    early_stopping_max_f1_buffer = macro_valid_epoch_f1
                else:
                    early_stopping_count += 1
                
                print("early_stopping_count : {}/{}".format(early_stopping_count, early_stopping_limit))
                if early_stopping_count >= early_stopping_limit:
                    print("early_stop")
                    break

        # 結果の保存
        train_log_dict = {"train_loss": train_all_loss, "train_recall": train_epoch_recall_list, "train_precision": train_epoch_precision_list, 
                        "train_f1": train_epoch_f1_list, "train_accuracy":train_epoch_accuracy_list}
        with open(f"{train_log_dir}/json_log.json", 'w') as f:
            json.dump(train_log_dict, f)
            
        valid_log_dict = {"valid_loss": valid_all_loss, "micro_valid_recall": micro_valid_epoch_recall_list, "micro_valid_precision": micro_valid_epoch_precision_list, 
                        "micro_valid_f1": micro_valid_epoch_f1_list, "micro_valid_accuracy":micro_valid_epoch_accuracy_list,
                        "macro_valid_recall": macro_valid_epoch_recall_list, "macro_valid_precision": macro_valid_epoch_precision_list, 
                        "macro_valid_f1": macro_valid_epoch_f1_list, "macro_valid_accuracy":macro_valid_epoch_accuracy_list,"cm": valid_cm_list}

        with open(f"{valid_log_dir}/json_log.json", 'w') as f:
            json.dump(valid_log_dict, f)

        test_log_dict = {"micro_test_recall": micro_test_epoch_recall_list, "micro_test_precision": micro_test_epoch_precision_list, 
                        "micro_test_f1": micro_test_epoch_f1_list, "micro_test_accuracy":micro_test_epoch_accuracy_list,
                        "macro_test_recall": macro_test_epoch_recall_list, "macro_test_precision": macro_test_epoch_precision_list, 
                        "macro_test_f1": macro_test_epoch_f1_list, "macro_test_accuracy":macro_test_epoch_accuracy_list,
                        "test_each_f1_list": test_each_label_f1_list, "cm": test_cm_list}

        with open(f"{test_log_dir}/json_log.json", 'w') as f:
            json.dump(test_log_dict, f)

        print("結果を保存しました")

すべてのデータの数は 2135
訓練データの数は 1281
検証データの数は 427
テストデータの数は 427
訓練データのバッチ数は 161
検証データのバッチ数は 54
テストデータのバッチ数は 54
ログを保存しました
proposal_A で学習を実行します
Epoch 1\300-------------------------------
学習中....
0.9161243057399062 train_loss
検証中....
検証macro_f1は 0.2264297717073366
テスト中....
[0.4122137404580153] micro_recall
[0.4127866574009729] micro_precision
[0.4125] micro_f1
[0.669208211143695] micro_accuracy
0.3557611652926841 macro_recall
0.4118773946360153 macro_precision
0.23887562170377646 macro_f1
0.669208211143695 macro_accuracy
ラベルごとのf1
ラベル1  0.043243243243243246
ラベル2  0.5353383458646617
ラベル3  0.6157965194109772
ラベル4  0.0
ラベル5  0.0
<epoch_time>  79  seconds
early_stopping_count : 0/30
Epoch 2\300-------------------------------
学習中....
0.8805765787266796 train_loss
検証中....
検証macro_f1は 0.2605060774340888
テスト中....
[0.4850798056904927] micro_recall
[0.3967082860385925] micro_precision
[0.4364658133000312] micro_f1
[0.6471163245356794] micro_accuracy
0.4214548236992185 macro_recall
0.5562776566037602 macro_