In [None]:
from IPython import get_ipython
def is_env_jupyter_notebook():
    """Return True only when running inside a Jupyter kernel.

    The check is based on the class name of the object returned by
    ``get_ipython()``: 'ZMQInteractiveShell' means Jupyter. Terminal IPython
    ('TerminalInteractiveShell') and any other class name -- including
    'NoneType' when ``get_ipython()`` returns None -- yield False.
    """
    shell_name = type(get_ipython()).__name__
    return shell_name == 'ZMQInteractiveShell'

In [None]:
if is_env_jupyter_notebook():
    import sys
    # Under Jupyter, make modules in the notebook's own directory importable.
    # The membership check keeps re-runs of this cell from appending
    # duplicate '.' entries to sys.path.
    if '.' not in sys.path:
        sys.path.append('.')
    # To import modules from the parent directory as well, add '..':
    # sys.path.append('..')

# import

In [None]:
import os
import glob
import gc
import cv2
import numpy as np
import pandas as pd
import pydicom
import copy
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from pydicom import dcmread
from tqdm import tqdm
from pydicom.data import get_testdata_file
from resnet.resnet_for_generation_outputdim256_avgpool import generate_model
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
from torch.optim import Adam
from scipy.ndimage.interpolation import zoom
from sklearn.metrics import f1_score
from transformers import BertJapaneseTokenizer, BertModel
from utils import(
    pickle_dump,
    pickle_load,
    load_vocab,
    get_keys_from_value
)

In [None]:
# gpuを使う場合
# os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# --- Reproducibility / training loop ---
SEED = 1234
BATCH_SIZE = 4
EPOCHS = 100
RESIZE = 256  # each DICOM slice is resized to RESIZE x RESIZE pixels
with_cuda = True

# --- CNN encoder ---
# Supported ResNet depths: 10, 18, 34, 50, 101, 152, 200
model_depth = 18
cnn_pretrain_weight = True

# --- BERT ---
bert_hidden_dim = 768
max_length = 256

# --- LSTM ---
lstm_input_length = 256

# --- Adam optimizer ---
lr = 1e-3
adam_beta1, adam_beta2 = 0.9, 0.999
adam_weight_decay = 0.01

# --- Decoder ---
attention_dim = 512
decoder_dim = 256
encoder_dim = 256
dropout = 0.1
alpha_c = 1.0
desired_slice = 50  # slice count each volume is interpolated to along axis 0
DA = 128  # size of the weight matrix used when attention is computed by a neural network
R = 1  # number of stacked attention layers
last_n_token = 5

# --- Output location ---
num = 'hoge_num'  # run identifier; outputs go to a directory of this name

# Directory and checkpoint number of the pre-trained CNN model
cnn_dir = 'hoge_cnn_dir'
cnn_num = 'hoge_cnn_num'

output_path = '/hoge/output/' + num + '/'
# exist_ok=False: raises if this run directory already exists
os.makedirs(output_path, exist_ok=False)


In [None]:
dataset_path = "/hoge/dataset/"  # dataset root; splits live in sub-directories such as 'train/'


In [None]:
# BERT released by Tohoku University, further pre-trained on radiology findings text.
JR_BERT_path = "/hoge/model_path"


In [None]:
# Checkpoint of the pre-trained CNN encoder:
#   /hoge/cnn_output/<cnn_dir>/ep<cnn_num>_hoge.bin
cnn_encoder_path = '/'.join(['/hoge/cnn_output', cnn_dir, 'ep' + cnn_num + '_hoge.bin'])


# 学習データ

In [None]:
# Training split. Judging by the file names these are: DICOM series paths,
# one-sentence radiology reports, per-series DICOM file counts and
# positive/negative (lung region) labels. They are pickles, so load only
# trusted files.
train_path = dataset_path + 'train/'
train_dicom_path_list = pickle_load(f'{train_path}hoge_dicom_path_list.pickle')
train_one_sentence_radiology_report_list = pickle_load(f'{train_path}hoge_radiology_report_list.pickle')
train_dicom_file_num_list = pickle_load(f'{train_path}hoge_dicom_file_num_list.pickle')
train_posi_nega_list = pickle_load(f'{train_path}hoge_posi_nega_lung_region_list.pickle')


# 画像取得

In [None]:
def get_images_list(series_path_list, wc=-500, ww=700):
    """Load, window, resize and z-interpolate every DICOM series in a list.

    For each series directory the slices are ordered with
    ``sortFileByImagePosition``. Each slice is converted to Hounsfield units,
    windowed with (``wc``, ``ww``) and resized in-plane with ``Resize``. The
    stacked volume is then spline-interpolated along axis 0 to
    ``desired_slice`` slices with ``resize_volume``.

    :param series_path_list: directories, each holding the DICOM files of one
        series ('/' is appended before listing)
    :param wc: window center in HU
    :param ww: window width in HU
    :return: list with one entry per series; each entry is ``[volume]`` where
        ``volume`` is a nested Python list (slices x rows x cols) of floats
    :raises ValueError: if a series directory contains no files
    """
    images_list = []
    for series_path in tqdm(series_path_list, bar_format="{l_bar}{r_bar}"):
        sorted_dicom_path_list = sortFileByImagePosition(series_path + '/')
        if not sorted_dicom_path_list:
            # Otherwise resize_volume fails with ZeroDivisionError on 0 slices.
            raise ValueError(f'no DICOM files found in {series_path!r}')

        slices = []
        for dicom_path in sorted_dicom_path_list:
            # dcmread replaces pydicom.read_file (deprecated, removed in pydicom 3).
            ds = dcmread(dicom_path)
            img_array = transform_to_hu(ds, ds.pixel_array)
            img_array = Windowing(img_array, wc, ww)
            slices.append(Resize(img_array))

        # float64 matches the values the float32 -> Python float -> np.array
        # round-trip used to produce, without the per-slice list copies.
        series_images_array = np.stack(slices).astype(np.float64)
        # Spline interpolation along the slice axis
        resized_series_images_array = resize_volume(series_images_array, desired_slice)
        # The extra list level preserves the existing per-series data layout.
        images_list.append([resized_series_images_array.tolist()])

    return images_list

In [None]:
# Convert to JPEG to check visually that the windowing result looks right.
def dicom2jpg(series_path_list, save_path, wc=-500, ww=1000):
    """Write each windowed, resized slice of every series as a JPEG.

    Slices are ordered with ``sortFileByImagePosition``, converted to HU,
    windowed and resized as in training (no z-interpolation). Each slice is
    written to ``save_path + '<slice index>.jpg'``.

    NOTE: file names depend only on the slice index, so when several series
    are given, later series overwrite the JPEGs of earlier ones.

    :param series_path_list: directories, each holding one DICOM series
    :param save_path: output prefix (typically a directory path ending in '/')
    :param wc: window center in HU
    :param ww: window width in HU
    """
    for series_path in tqdm(series_path_list, bar_format="{l_bar}{r_bar}"):
        sorted_dicom_path_list = sortFileByImagePosition(series_path + '/')

        for ind, dicom_path in enumerate(sorted_dicom_path_list):
            ds = dcmread(dicom_path)
            img_array = transform_to_hu(ds, ds.pixel_array)
            img_array = Windowing(img_array, wc, ww)
            img_array = Resize(img_array)
            # Windowing yields values in [0, 1]. Map them explicitly onto
            # 0..255 instead of writing *256 floats that exceed the 8-bit range.
            jpg_array = np.clip(np.rint(img_array * 255), 0, 255).astype(np.uint8)

            jpg_save_path = save_path + str(ind) + '.jpg'
            cv2.imwrite(jpg_save_path, jpg_array)

In [None]:
# Convert to JPEG to check visually that the windowing result looks right,
# after spline interpolation along the slice axis.
def dicom2jpg_spline(series_path_list, save_path, wc=-500, ww=1000):
    """Write the spline-interpolated slices of every series as JPEGs.

    Slices are loaded, converted to HU, windowed and resized as in training.
    They are then scaled to 0..255, interpolated along axis 0 to
    ``desired_slice`` slices with ``resize_volume``, clipped and written to
    ``save_path + '<slice index>.jpg'``.

    NOTE: file names depend only on the slice index, so when several series
    are given, later series overwrite the JPEGs of earlier ones.

    :param series_path_list: directories, each holding one DICOM series
    :param save_path: output prefix (typically a directory path ending in '/')
    :param wc: window center in HU
    :param ww: window width in HU
    :raises ValueError: if a series directory contains no files
    """
    for series_path in tqdm(series_path_list, bar_format="{l_bar}{r_bar}"):
        sorted_dicom_path_list = sortFileByImagePosition(series_path + '/')
        if not sorted_dicom_path_list:
            raise ValueError(f'no DICOM files found in {series_path!r}')

        slices = []
        for dicom_path in sorted_dicom_path_list:
            ds = dcmread(dicom_path)
            img_array = transform_to_hu(ds, ds.pixel_array)
            img_array = Windowing(img_array, wc, ww)
            slices.append(Resize(img_array))
        series_img_array = np.stack(slices).astype(np.float64) * 255
        resized_series_img_array = resize_volume(series_img_array, desired_slice)

        for ind, array in enumerate(resized_series_img_array):
            # Cubic splines can overshoot [0, 255]; clip before the uint8 cast.
            jpg_array = np.clip(np.rint(array), 0, 255).astype(np.uint8)
            jpg_save_path = save_path + str(ind) + '.jpg'
            cv2.imwrite(jpg_save_path, jpg_array)

In [None]:
def sortFileByImagePosition(dirname):
    """Return the paths of the DICOM files in ``dirname`` sorted by slice position.

    Every file in the directory is read (header only) and the files are
    ordered by the z component of ``ImagePositionPatient``, in descending
    order. Files with equal positions keep file-name order.

    :param dirname: directory to scan (callers here pass it with a trailing '/')
    :return: list of paths built with ``os.path.join(dirname, filename)``
    """
    positions = []
    # sorted() makes the tie order deterministic rather than os.listdir order.
    for filename in sorted(os.listdir(dirname)):
        # Only a header tag is needed, so skip reading the pixel data.
        # dcmread replaces the deprecated pydicom.read_file.
        ds = pydicom.dcmread(os.path.join(dirname, filename), stop_before_pixels=True)
        positions.append((filename, ds.ImagePositionPatient[2]))

    # Descending z position (set reverse=False for ascending). The sort is
    # stable, so ties stay in file-name order.
    positions.sort(key=lambda item: item[1], reverse=True)
    return [os.path.join(dirname, filename) for filename, _ in positions]

## 画像前処理

In [None]:
def Resize(img_array):
    """Resize an image (one slice here) to RESIZE x RESIZE pixels as float32."""
    target_size = (RESIZE, RESIZE)
    return cv2.resize(img_array, target_size).astype(np.float32)

In [None]:
# Window a CT image so the target region is easy to see.
# For lung fields, a window center around -500 HU and a width around 1000 HU
# reportedly work well. Images can carry their own window values, but a fixed
# window is used here.
def Windowing(img_array, wc, ww):
    """Map the HU range [wc - ww/2, wc + ww/2] linearly onto [0, 1].

    Values outside the window are clipped to 0 or 1. A new array is returned;
    ``img_array`` is not modified.
    """
    scaled = (img_array - wc + ww / 2) / ww
    return np.clip(scaled, 0, 1)

In [None]:
# Handle the (0028,1052) Rescale Intercept tag.
# Some images are stored with a +1024 offset; this undoes it and converts the
# stored values back to Hounsfield units (water: 0 HU, air: -1000 HU).
def transform_to_hu(medical_image, image):
    """Return ``image * RescaleSlope + RescaleIntercept``.

    :param medical_image: DICOM dataset providing RescaleIntercept and RescaleSlope
    :param image: stored pixel values (e.g. ``ds.pixel_array``)
    :return: rescaled values as a new array
    """
    intercept, slope = medical_image.RescaleIntercept, medical_image.RescaleSlope
    return image * slope + intercept

In [None]:
# Spline interpolation along the slice axis
def resize_volume(img, desired_slice=100):
    """Resize a 3-D volume along axis 0 to about ``desired_slice`` slices.

    Uses cubic-spline interpolation (``zoom(..., order=3)``). The other two
    axes keep their size. The resulting slice count is zoom's rounding of
    ``n_slices * factor``.
    """
    n_slices = img.shape[0]
    z_factor = 1 / (n_slices / desired_slice)
    return zoom(img, (z_factor, 1, 1), order=3)

In [None]:
# Preprocessed training volumes (window center -500 HU, width 700 HU: the function defaults)
train_images_list = get_images_list(train_dicom_path_list, wc=-500, ww=700)

In [None]:
# Japanese BERT tokenizer; 'DATE' is registered as an additional special
# token so the tokenizer keeps it as a single token.
additional_special_tokens = ['DATE']
tokenizer = BertJapaneseTokenizer.from_pretrained(
    JR_BERT_path, additional_special_tokens=additional_special_tokens)

# captionのidを作成

In [None]:
def remove_subword(text):
    """Tokenize ``text`` and merge subword pieces back into whole words.

    A token containing '#' is treated as a continuation piece. Its '#'
    characters are removed and the rest is appended to the preceding word.
    The embedding-matrix construction uses the same "'#' in token" rule to
    decide word boundaries.

    :param text: one report sentence
    :return: list of words
    """
    words = []
    for piece in tokenizer.tokenize(text):
        if '#' in piece and words:
            words[-1] += piece.replace('#', '')
        else:
            # Also taken when a '#' token comes first. There is no previous
            # word to attach it to, so it starts its own word instead of
            # indexing an empty list.
            words.append(piece)
    return words

In [None]:
# Word vocabulary built from the training reports.
# Ids 0-5 are reserved for special tokens. Every other word gets the next
# free id in order of first appearance. The literal 'DATE' placeholder is
# skipped (make_caption_id maps it to <DATE>).
train_word_dic = {
    '<PAD>': 0,
    '<MASK>': 1,
    '<START>': 2,
    '<END>': 3,
    '<DATE>': 4,
    '<UNK>': 5,
}
ind = len(train_word_dic)  # next free id (6)
for text in train_one_sentence_radiology_report_list:
    for word in remove_subword(text):
        if word != 'DATE' and word not in train_word_dic:
            train_word_dic[word] = ind
            ind += 1

In [None]:
def id2word(id_):
    """Return the vocabulary word stored at position ``id_`` of ``keys_list``."""
    word = keys_list[id_]
    return word

In [None]:
# Words in id order. Ids were assigned consecutively from 0 in insertion
# order, so position i holds the word whose id is i.
keys_list = list(train_word_dic)

# captionのtargetを作成

In [None]:
def make_caption_id(text_list):
    """Convert each report into vocabulary ids framed by <START> and <END>.

    Words come from ``remove_subword``. A word missing from ``train_word_dic``
    becomes ``<DATE>`` if it is the literal 'DATE', otherwise ``<UNK>``.

    :param text_list: iterable of report sentences
    :return: list of id lists, one per sentence
    """
    def to_id(word):
        if word in train_word_dic:
            return train_word_dic[word]
        return train_word_dic['<DATE>'] if word == 'DATE' else train_word_dic['<UNK>']

    return [
        [train_word_dic['<START>']]
        + [to_id(word) for word in remove_subword(text)]
        + [train_word_dic['<END>']]
        for text in text_list
    ]

In [None]:
# Target id sequences for the training reports
train_caption_id_list = make_caption_id(text_list=train_one_sentence_radiology_report_list)

# embedding_matrix,input_ids_listを作成

In [None]:
# Honour the with_cuda switch from the config cell, as the trainer does.
device = torch.device('cuda' if torch.cuda.is_available() and with_cuda else "cpu")

In [None]:
# Frozen BERT, used below only to precompute word embeddings.
bert = BertModel.from_pretrained(JR_BERT_path).to(device)
# Grow the token embedding table to cover the added special token(s).
bert.resize_token_embeddings(len(tokenizer))
# Freeze after resizing so the new embedding rows are frozen too.
bert.requires_grad_(False);

In [None]:
# Total number of subword tokens (no [CLS]/[SEP]) over all training reports;
# used below to size the embedding matrix.
total_subword_num = sum(
    tokenizer.encode_plus(
        sentence,
        add_special_tokens=False,
        return_attention_mask=False,
        return_length=True,
    )['length']
    for sentence in train_one_sentence_radiology_report_list
)

In [None]:
# Row layout of embedding_matrix:
#   row 0  -> padding (all zeros)
#   row 1  -> start embedding (assigned further below)
#   rows 2.. -> one BERT vector per word occurrence in the training reports
# input_ids_list[i] lists the rows used by report i, starting with the start row.
embedding_matrix = np.zeros((total_subword_num+2, 768), dtype=np.float32)
embedding_matrix_ind = 2  # next free row
input_ids_list = []

for sentence in tqdm(train_one_sentence_radiology_report_list, bar_format="{l_bar}{r_bar}"):
    # Every sequence begins with the start embedding (row 1). Row 0 is the
    # all-zero padding row and would make the start step look like padding.
    input_id_list = [1]
    prev_last_subword_ind = 0  # subword tokens consumed by already-emitted words
    total_subword_len = 0      # subword tokens covered so far, incl. the word being closed

    encoded_dict = tokenizer.encode_plus(
        sentence,                      # Sentence to encode.
        add_special_tokens=True,  # '[CLS]' and '[SEP]'
        return_attention_mask=True,   # Construct attn. masks.
        return_length=True
    )
    subword_token_list = []
    for id_ in encoded_dict['input_ids']:
        subword_token_list.append(tokenizer.convert_ids_to_tokens(id_))

    # At iteration ind, subword_token_list[ind + 1] is the latest piece of the
    # current word and subword_token_list[ind + 2] is the token after it
    # (index 0 is [CLS]). A '#' token extends the current word; any other
    # token closes it. [SEP] closes the last word and ends the loop.
    # NOTE: a report with no tokens between [CLS] and [SEP] raises IndexError here.
    subword_len = 1
    for ind in range(len(subword_token_list)):
        if subword_token_list[ind + 2] == '[SEP]':
            total_subword_len += subword_len
            # Run BERT on the prefix [CLS] .. end of the current word only,
            # so each word vector sees left context only.
            bert_input_ids = encoded_dict['input_ids'][:total_subword_len+1]
            token_type_ids = encoded_dict['token_type_ids'][:total_subword_len+1]
            bert_input_ids = torch.tensor([bert_input_ids]).to(device)
            token_type_ids = torch.tensor([token_type_ids]).to(device)

            bert_outputs = bert(
                bert_input_ids,
                token_type_ids=token_type_ids
            )
            bert_last_hidden_state = bert_outputs.last_hidden_state
            # Hidden states of the current word's pieces (+1 skips [CLS]), summed
            subword_embeddings = bert_last_hidden_state[0][prev_last_subword_ind+1:prev_last_subword_ind+subword_len+1]
            
            embeddings = torch.sum(subword_embeddings,dim=0).tolist()
            embedding_matrix[embedding_matrix_ind] = embeddings
            input_id_list.append(embedding_matrix_ind)
            
            embedding_matrix_ind += 1
            
            break
            
        if '#' in subword_token_list[ind + 2]:
            subword_len += 1
        else:
            total_subword_len += subword_len
            bert_input_ids = encoded_dict['input_ids'][:total_subword_len+1]
            token_type_ids = encoded_dict['token_type_ids'][:total_subword_len+1]
            bert_input_ids = torch.tensor([bert_input_ids]).to(device)
            token_type_ids = torch.tensor([token_type_ids]).to(device)

            bert_outputs = bert(
                bert_input_ids,
                token_type_ids=token_type_ids
            )
            bert_last_hidden_state = bert_outputs.last_hidden_state
            subword_embeddings = bert_last_hidden_state[0][prev_last_subword_ind+1:prev_last_subword_ind+subword_len+1]
            prev_last_subword_ind += subword_len
            subword_len = 1

            embeddings = torch.sum(subword_embeddings,dim=0).tolist()
            embedding_matrix[embedding_matrix_ind] = embeddings
            input_id_list.append(embedding_matrix_ind)
            
            embedding_matrix_ind += 1
    
    input_ids_list.append(input_id_list)

In [None]:
# Start-token embedding drawn uniformly from [-1, 1); reproducible via SEED.
np.random.seed(seed=SEED)
start_embeddings = (np.random.rand(768) * 2 - 1).tolist()

In [None]:
embedding_matrix[1, :] = start_embeddings  # row 1 holds the start-token embedding

In [None]:
# Row 0 (padding) is set to all zeros; np.zeros already initialised it that way.
unk_list = [0.0] * 768
embedding_matrix[0] = unk_list

In [None]:
# The word embeddings are cached in embedding_matrix; release the BERT model.
del bert
gc.collect()

In [None]:
# Return cached, now-unused GPU memory (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# dataset作成

In [None]:
class Resnet_Dataset(torch.utils.data.Dataset):
    """Parallel per-sample lists exposed as a torch Dataset.

    Item ``idx`` is the tuple (image, embedding-row ids as tensor, caption ids
    as tensor, ``[caption length]``, DICOM file count, positive/negative label).
    """

    def __init__(self, dicom_images, bert_output_id, caption_id, dicom_file_num, posi_nega):
        self.dicom_images = dicom_images
        self.bert_output_id = bert_output_id
        self.caption_id = caption_id
        self.dicom_file_num = dicom_file_num
        self.posi_nega = posi_nega
        # The dataset length is fixed from the image list at construction.
        self.data_num = len(dicom_images)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        caption = self.caption_id[idx]
        return (
            self.dicom_images[idx],
            torch.tensor(self.bert_output_id[idx]),
            torch.tensor(caption),
            [len(caption)],
            self.dicom_file_num[idx],
            self.posi_nega[idx],
        )

In [None]:
def my_collate(batch):
    """Transpose a batch of 6-field samples into six per-field lists.

    No stacking or padding is done, so variable-length fields stay as lists.
    """
    return tuple([sample[field] for sample in batch] for field in range(6))

In [None]:
# Seed torch first so later torch randomness (e.g. loader shuffling) is reproducible.
torch.manual_seed(SEED)
dataset_train = Resnet_Dataset(
    dicom_images=train_images_list,
    bert_output_id=input_ids_list,
    caption_id=train_caption_id_list,
    dicom_file_num=train_dicom_file_num_list,
    posi_nega=train_posi_nega_list,
)

In [None]:
# Drop the notebook-level names. Only the report list is actually freed:
# the other lists are still referenced by dataset_train.
del (
    train_images_list,
    input_ids_list,
    train_one_sentence_radiology_report_list,
    train_caption_id_list,
    train_dicom_file_num_list,
    train_posi_nega_list,
)
gc.collect()

In [None]:
train_data_loader = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=my_collate,  # keeps each field as a plain list
)

# subwordのembeddingsを足す

In [None]:
def add_subword_embeddings(bert_last_hidden_state, target_id_captions, subword_id_captions, start_embedding):
    """Sum BERT subword vectors into one vector per target word and pad each caption.

    For each caption, the BERT subword tokens (without [CLS]/[SEP]) are walked
    alongside the target word ids (without <START>/<END>). A subword equal to
    the target word is used as is. Otherwise consecutive subwords (with '#'
    removed) are accumulated until their concatenation equals the target word,
    and their hidden states are summed.

    :param bert_last_hidden_state: hidden states indexed ``[batch][position]``
        with position 0 = [CLS]; presumably (batch, seq_len, hidden) -- TODO confirm
    :param target_id_captions: per caption, ``train_word_dic`` ids incl. <START>/<END>
    :param subword_id_captions: per caption, BERT token ids incl. [CLS]/[SEP]
    :param start_embedding: vector placed before every caption's word vectors
    :return: stacked tensor with, per caption, the start embedding followed by
        one vector per target word, zero-padded to ``lstm_input_length`` rows.
        ``bert_last_hidden_state`` is not modified.
    """
    embeddings = []
    for batch_ind, (target_id_caption, subword_id_caption) in enumerate(
            zip(target_id_captions, subword_id_captions)):
        tokens_embedding = [start_embedding]
        total_token_num = 0  # subword tokens consumed by earlier target words
        subword_tokenized_caption = tokenizer.convert_ids_to_tokens(subword_id_caption[1:-1])  # drop [CLS], [SEP]
        for target_id in target_id_caption[1:-1]:  # drop <START>, <END>
            current_token = ''
            subword_num = 0
            target_token = get_keys_from_value(train_word_dic, target_id)[0]

            for ind, _ in enumerate(subword_tokenized_caption):
                subword_token = subword_tokenized_caption[ind + total_token_num]
                if subword_token == '[PAD]':
                    break
                piece_embedding = bert_last_hidden_state[batch_ind][ind + total_token_num + 1]  # +1 skips [CLS]

                if subword_token == target_token:
                    tokens_embedding.append(piece_embedding)
                    total_token_num += 1
                    break
                else:
                    subword_num += 1

                    if current_token == '':
                        tokens_embedding.append(piece_embedding)
                        current_token += subword_token.replace('#', '')

                    else:
                        # Out-of-place add: tokens_embedding[-1] may be a view
                        # into bert_last_hidden_state, so '+=' would overwrite
                        # the caller's hidden states.
                        tokens_embedding[-1] = tokens_embedding[-1] + piece_embedding
                        current_token += subword_token.replace('#', '')

                        if current_token == target_token:
                            total_token_num += subword_num
                            break
        tokens_embedding = torch.stack(tokens_embedding)  # (num_words + 1, hidden)

        packed = pack_sequence([tokens_embedding])
        padded, _ = pad_packed_sequence(packed, batch_first=True, total_length=lstm_input_length)
        # padded is (1, lstm_input_length, hidden). Index the single row rather
        # than squeeze(), which would also drop other size-1 dimensions.
        embeddings.append(padded[0])

    return torch.stack(embeddings)
    

# モデル

In [None]:
def last_n_token_tensor(seq_tensor, length_list, last_n_token):
    """Keep only the last ``last_n_token`` valid steps of each row, left-aligned.

    For row i with valid length L, steps ``max(L - last_n_token, 0) .. L-1``
    are copied to the front of the output row, and everything after them is 0.

    :param seq_tensor: tensor whose dim 0 is the batch and dim 1 the time axis
    :param length_list: valid length of each row (ints or 0-d tensors)
    :param last_n_token: window size
    :return: tensor with the same shape, dtype and device as ``seq_tensor``
    """
    new_seq_tensor = torch.zeros_like(seq_tensor)
    for row, (seq, len_) in enumerate(zip(seq_tensor, length_list)):
        len_ = int(len_)
        # Clamp the start at 0: for rows shorter than the window, a negative
        # start would wrap around to the padded tail. Stop at the valid length
        # so padding never enters the window.
        window = seq[max(len_ - last_n_token, 0):len_]
        new_seq_tensor[row, :window.shape[0]] = window
    return new_seq_tensor

In [None]:
class Classification(nn.Module):
    """Linear classifier over globally average-pooled encoder features.

    ``forward`` returns raw logits. ``self.softmax`` is created but not applied
    in ``forward``.
    """

    def __init__(self, hidden=encoder_dim, n_classes=2):
        super().__init__()
        self.linear = nn.Linear(hidden, n_classes)
        self.softmax = nn.LogSoftmax(dim=-1)  # unused in forward

    def forward(self, x):
        # Merge every axis between batch and the last (feature) axis into one
        # "pixel" axis, then average over it (global average pooling).
        # Assumes channels-last input such as CNN_Encoder's output -- TODO confirm.
        n_batch, n_features = x.size(0), x.size(-1)
        pooled = x.view(n_batch, -1, n_features).mean(dim=1)
        return self.linear(pooled)

In [None]:
class CNN_Encoder(nn.Module):
    """ResNet image encoder (built by ``generate_model``) with channels-last output.

    :param device: stored as ``self.device``
    :param model_depth: ResNet depth for ``generate_model`` (10, 18, 34, 50, 101, 152 or 200)
    :param encoded_image_size: passed to ``generate_model``; stored as ``enc_image_size``
    :param bert_requires_grad: accepted for signature compatibility; unused
    """

    def __init__(self, device, model_depth, encoded_image_size=14, bert_requires_grad=False):
        super().__init__()
        self.device = device
        self.enc_image_size = encoded_image_size
        self.resnet = generate_model(encoded_image_size, model_depth)

    def forward(self, images):
        features = self.resnet(images)
        # 4-D feature map with channels at dim 1 -> move channels last:
        # (batch, C, A, B) -> (batch, A, B, C)
        return features.permute(0, 2, 3, 1)

In [None]:
class BERT_Decoder(nn.Module):
    """Tokenize caption strings and encode them with BERT.

    :param bert: model called as ``bert(input_ids, attention_mask=..., token_type_ids=...)``
    :param device: device the tokenized tensors are moved to
    :param bert_requires_grad: False (default) freezes BERT's parameters and runs
        it without gradient tracking; True lets gradients reach BERT
    """

    def __init__(self, bert, device, bert_requires_grad=False):
        super(BERT_Decoder, self).__init__()

        self.bert = bert
        self.bert_requires_grad = bert_requires_grad
        # Freeze BERT's weights unless it is meant to be fine-tuned
        if not bert_requires_grad:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.device = device

    def forward(self, captions):
        """Tokenize ``captions`` (padded/truncated to ``max_length``) and run BERT.

        :param captions: iterable of caption strings
        :return: (BERT outputs, input-id tensor of shape (batch, max_length))
        """
        # Batched tokenization. padding='max_length' replaces the deprecated
        # pad_to_max_length flag. truncation=True keeps every row exactly
        # max_length long, so over-long captions no longer break batching.
        encoded = tokenizer(
            list(captions),
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors='pt',
        )
        bert_input_ids = encoded['input_ids'].to(self.device)
        attention_masks = encoded['attention_mask'].to(self.device)
        token_type_ids = encoded['token_type_ids'].to(self.device)

        # Track gradients only when BERT is being fine-tuned and autograd is
        # enabled in the caller's context (so eval under no_grad stays cheap).
        with torch.set_grad_enabled(self.bert_requires_grad and torch.is_grad_enabled()):
            bert_outputs = self.bert(
                bert_input_ids,
                attention_mask=attention_masks,
                token_type_ids=token_type_ids
            )

        return bert_outputs, bert_input_ids

In [None]:
class Attention(nn.Module):
    """
    Soft attention over encoded image positions: scores come from the sum of
    projected image features and projected decoder state, passed through ReLU.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super().__init__()
        # Layer creation order is unchanged, so seeded initialisation is too.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # projects image features
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # projects decoder hidden state
        self.full_att = nn.Linear(attention_dim, 1)  # one score per position
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalises scores over positions

    def forward(self, encoder_out, decoder_hidden):
        """
        :param encoder_out: encoded images, (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, (batch_size, decoder_dim)
        :return: (attention-weighted encoding (batch_size, encoder_dim),
                  attention weights (batch_size, num_pixels))
        """
        image_proj = self.encoder_att(encoder_out)
        hidden_proj = self.decoder_att(decoder_hidden)[:, None, :]
        scores = self.full_att(self.relu(image_proj + hidden_proj)).squeeze(2)
        alpha = self.softmax(scores)
        context = (encoder_out * alpha[:, :, None]).sum(dim=1)
        return context, alpha

In [None]:
class Decoder(nn.Module):
    """Attention LSTM decoder that predicts report words from encoder features.

    At each step the LSTMCell receives the positive/negative flag, the frozen
    precomputed word embedding for that step, and a gated attention-weighted
    image encoding. The embedding weights come from the module-level
    ``embedding_matrix``, not from a constructor argument.
    """

    def __init__(
        self,
        attention_dim,
        embed_dim, 
        decoder_dim, 
        decoder_vocab_size,
        device,
        encoder_dim=512, 
        dropout=0.1,
        subword_num=32000
    ):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size (must equal the width of ``embedding_matrix``)
        :param decoder_dim: size of decoder's RNN
        :param decoder_vocab_size: size of the output vocabulary (``fc`` output width)
        :param device: device the attention and embedding modules are moved to
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout probability applied before ``fc``
        :param subword_num: initial embedding table size; the table is replaced by
            ``embedding_matrix`` at the end of ``__init__``
        """
        super(Decoder, self).__init__()
        
        self.device = device

        self.vocab_size = decoder_vocab_size
        
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim).to(self.device)  # attention network
        
        self.dropout = nn.Dropout(p=dropout)
        # +1 input feature for the positive/negative flag concatenated in forward()
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim + 1, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell

        # Possible alternative for init_hidden_state: project from num_pixels instead
#         self.init_h = nn.Linear(num_pixels, decoder_dim)  
#         self.init_c = nn.Linear(num_pixels, decoder_dim)  
        
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, self.vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

        self.embedding = nn.Embedding(subword_num, embed_dim).to(self.device)
        # Replace the random table with the precomputed module-level embedding_matrix
        # (a global defined elsewhere in the notebook).
        self.init_embeddings(embedding_matrix)
        
    def init_embeddings(self, weights):
        """Install the NumPy array ``weights`` as the frozen embedding table.

        The new Parameter is built from a NumPy array (i.e. on CPU). It reaches
        the target device only through a later ``.to(device)`` on the whole
        module, which the trainer performs.
        """
        self.embedding.weight = nn.Parameter(torch.from_numpy(weights))
        
        # Do not update these weights during training
        self.embedding.weight.requires_grad = False
        
    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        
    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        # Global average pooling over the pixel axis:
        # (batch, num_pixels, encoder_dim) -> (batch, encoder_dim)
        mean_encoder_out = encoder_out.mean(dim=1)
        
        # Alternative considered: average each pixel over channels, then map
        # num_pixels -> decoder_dim with a fully connected layer.
#         mean_encoder_out = encoder_out.mean(dim=-1) # (batch, num_pixels)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, bert_output_id, captions_id, caption_lengths, posi_nega):
        """Teacher-forced decoding of one batch.

        :param encoder_out: encoder features; flattened below to (batch, num_pixels, encoder_dim)
        :param bert_output_id: embedding-row ids indexed [batch, t]
        :param captions_id: sequence of 1-D target id tensors (padded here with pad_sequence)
        :param caption_lengths: caption lengths; ``.squeeze(1)`` implies a size-1
            dim 1, presumably (batch, 1) -- TODO confirm
        :param posi_nega: per-sample label; made (batch, 1) float and fed at every step
        :return: predictions (batch, max_decode_len, vocab_size), padded captions,
            decode lengths, sort indices, attention weights
            (batch, max_decode_len, num_pixels). All per-sample outputs are in
            sorted (descending length) order. predictions/alphas are created by
            torch.zeros without a device argument.
        """

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing caption length so the samples still being decoded
        # at step t form the prefix [:batch_size_t] of the batch.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        
        encoder_out = encoder_out[sort_ind]
        bert_output_id = bert_output_id[sort_ind]
        posi_nega = posi_nega.unsqueeze(dim=-1)
        posi_nega = posi_nega[sort_ind].float()

        lstm_inputs = self.embedding(bert_output_id)
    
        pad_caps_sorted = pad_sequence(captions_id, batch_first=True)
        pad_caps_sorted = pad_caps_sorted[sort_ind]
        
        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word prediction scores and alphas
        # (no device argument, so each step's output is copied into them)
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size)#.to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels)#.to(device)

        for t in range(max(decode_lengths)):
            # Captions are sorted by decreasing length, so only the first
            # batch_size_t samples still have a word to predict at step t.
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            # Gate the attention output with a sigmoid of the hidden state
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gate, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            # LSTM input = [posi_nega flag | word embedding at step t | gated attention encoding]
            h, c = self.decode_step(
                torch.cat([posi_nega[:batch_size_t], lstm_inputs[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, pad_caps_sorted, decode_lengths,  sort_ind, alphas

In [None]:
class CNN_BERT_LSTM_Trainer:
    """Jointly trains the CNN image encoder, the image classification head and the
    attention LSTM caption decoder, logging per-epoch metrics to pickle files.

    Each epoch (``train`` or ``dev``) records:
      * loss  -- image-classification cross-entropy + caption cross-entropy
      * bleu  -- corpus BLEU of the argmax of the teacher-forced decoder scores
      * image_binary_acc / f1 -- metrics of the classification head
    """

    def __init__(
        self,
        train_dataloader: DataLoader,
        dev_dataloader: DataLoader = None,
        with_cuda: bool = True,
        lr: float = 1e-4,
        betas=(0.9, 0.999),
        weight_decay: float = 0.0,
        attention_dim: int = 512,
        embed_dim: int = 768,
        decoder_dim: int = 512,
        decoder_vocab_size: int = 32000,
        encoder_dim: int = 512,
        dropout: float = 0.1,
        output_path=None,
        subword_num: int = 32000,
        model_depth: int = 18,
        last_n_token: int = 5,
        cnn_pretrain_weight: bool = False,
        cnn_encoder_weight_path: str = None,
    ):
        """Build the models, optimizers and metric-history files.

        Args:
            train_dataloader: yields batches of
                (images, bert_output_id, captions_id, caption_lengths, dicom_file_num, posi_nega).
            dev_dataloader: optional loader with the same batch layout, used by ``dev``.
            output_path: string prefix (concatenated, not path-joined) for
                'output_train.pickle' and 'output_dev.pickle'.
            last_n_token: currently unused; kept for interface compatibility.
            cnn_pretrain_weight: if True, load CNN encoder weights before training.
            cnn_encoder_weight_path: state_dict file for the CNN encoder; when None the
                notebook-level ``cnn_encoder_path`` is used.
        """
        # set up cuda device
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device('cuda' if cuda_condition else "cpu")

        self.cnn_encoder = CNN_Encoder(
            self.device, model_depth).to(self.device)

        if cnn_pretrain_weight:
            weight_path = (cnn_encoder_weight_path
                           if cnn_encoder_weight_path is not None
                           else cnn_encoder_path)
            # map_location lets weights saved on a GPU load on a CPU-only machine
            self.cnn_encoder.load_state_dict(
                torch.load(weight_path, map_location=self.device))

        self.classification = Classification(encoder_dim).to(self.device)

        self.decoder = Decoder(
            attention_dim,
            embed_dim,
            decoder_dim,
            decoder_vocab_size,
            self.device,
            encoder_dim,
            dropout,
            subword_num
        ).to(self.device)

        self.train_dataloader = train_dataloader
        # Always defined so dev() can fail with a clear message when no loader was given.
        self.dev_dataloader = dev_dataloader

        # The classification head is optimised together with the CNN encoder it sits on;
        # without this its parameters belong to no optimizer and never change.
        self.cnn_encoder_optimizer = Adam(
            list(self.cnn_encoder.parameters()) + list(self.classification.parameters()),
            lr=lr,
            betas=betas,
            weight_decay=weight_decay
        )

        self.decoder_optimizer = Adam(
            self.decoder.parameters(),
            lr=lr,
            betas=betas,
            weight_decay=weight_decay
        )

        # loss functions
        self.decoder_criterion = nn.CrossEntropyLoss()
        self.encoder_criterion = nn.CrossEntropyLoss()
        # Not part of the current loss; kept so code referencing the attribute still works.
        self.last_n_token_criterion = nn.CrossEntropyLoss()

        # prepare metric-history files
        self.output_train_path = output_path + 'output_train.pickle'
        pickle_dump(self._empty_history(), self.output_train_path)

        # `is not None`: a DataLoader is falsy when its dataset is empty
        if dev_dataloader is not None:
            self.output_dev_path = output_path + 'output_dev.pickle'
            pickle_dump(self._empty_history(), self.output_dev_path)

    @staticmethod
    def _empty_history():
        """Return a fresh per-epoch metric history dict."""
        return {'loss': [], 'bleu': [], 'image_binary_acc': [], 'f1': []}

    def train(self, epoch):
        """Run one optimisation epoch over ``train_dataloader`` and append its
        loss, BLEU-4, image accuracy and F1 to ``output_train_path``."""
        self.cnn_encoder.train()
        self.classification.train()
        self.decoder.train()

        avg_loss, bleu4, image_binary_acc, f1 = self._run_epoch(
            self.train_dataloader, 'train', epoch, optimize=True)

        self.save_loss_bleu(avg_loss, self.output_train_path, bleu4, image_binary_acc, f1)

    def dev(self, epoch):
        """Evaluate on ``dev_dataloader`` without gradient updates.

        Appends the same metrics as ``train`` to ``output_dev_path``.

        Returns:
            The corpus BLEU-4 score of the epoch.

        Raises:
            ValueError: if the trainer was built without a dev_dataloader.
        """
        if self.dev_dataloader is None:
            raise ValueError('dev() requires a dev_dataloader passed to the constructor')

        self.cnn_encoder.eval()
        self.classification.eval()
        self.decoder.eval()

        with torch.no_grad():
            avg_loss, bleu4, image_binary_acc, f1 = self._run_epoch(
                self.dev_dataloader, 'dev', epoch, optimize=False)

        self.save_loss_bleu(avg_loss, self.output_dev_path, bleu4, image_binary_acc, f1)
        return bleu4

    def _run_epoch(self, dataloader, phase, epoch, optimize):
        """Iterate ``dataloader`` once; step the optimizers when ``optimize`` is True.

        Returns:
            (avg_loss, bleu4, image_binary_acc, f1) for the epoch.
        """
        data_iter = tqdm(
            dataloader,
            desc="EP_%s:%d" % (phase, epoch),
            total=len(dataloader),
            bar_format="{l_bar}{r_bar}"
        )

        total_loss = 0.0
        references = []  # true captions for calculating BLEU-4 score
        hypotheses = []  # predictions
        correct_labels = 0
        n_labels = 0
        predicted_labels = []
        target_labels = []

        for batch in data_iter:
            (loss, classification_output, posi_nega,
             scores_copy, caps_sorted, decode_lengths) = self._forward_and_loss(batch)

            if optimize:
                self.decoder_optimizer.zero_grad()
                self.cnn_encoder_optimizer.zero_grad()
                loss.backward()
                self.decoder_optimizer.step()
                self.cnn_encoder_optimizer.step()

            total_loss += loss.item()

            batch_refs, batch_hyps = self._words_from_batch(caps_sorted, scores_copy, decode_lengths)
            references.extend(batch_refs)
            hypotheses.extend(batch_hyps)

            # image classification acc / progress of the classifier
            predicted = classification_output.argmax(dim=-1)
            correct_labels += predicted.eq(posi_nega).sum().item()
            n_labels += posi_nega.nelement()
            predicted_labels.extend(predicted.tolist())
            target_labels.extend(posi_nega.tolist())

        bleu4 = corpus_bleu(
            references,
            hypotheses,
            smoothing_function=SmoothingFunction().method7
        )
        avg_loss = total_loss / len(dataloader)
        image_binary_acc = correct_labels / n_labels
        f1 = f1_score(target_labels, predicted_labels)
        return avg_loss, bleu4, image_binary_acc, f1

    def _unpack_batch(self, batch):
        """Convert one dataloader batch to device tensors.

        Returns (images, bert_output_id, captions_id, caption_lengths, posi_nega);
        ``captions_id`` is passed through unchanged and dicom_file_num is unused.
        """
        images_list, bert_output_id, captions_id, caption_lengths, _dicom_file_num, posi_nega_list = batch
        images = torch.tensor(images_list, dtype=torch.float32).to(self.device)
        bert_output_id = pad_sequence(bert_output_id, batch_first=True).to(self.device)
        caption_lengths = torch.tensor(caption_lengths).to(self.device)
        posi_nega = torch.tensor(posi_nega_list).to(self.device)
        return images, bert_output_id, captions_id, caption_lengths, posi_nega

    def _forward_and_loss(self, batch):
        """Run encoder, classifier and decoder on one batch and compute the summed loss.

        Returns:
            (loss, classification_output, posi_nega, scores_copy, caps_sorted, decode_lengths)
            where ``scores_copy`` is a detached, still-padded copy of the decoder scores.
        """
        images, bert_output_id, captions_id, caption_lengths, posi_nega = self._unpack_batch(batch)

        # Encoder
        encoder_out = self.cnn_encoder(images)

        # Image Classification
        classification_output = self.classification(encoder_out)

        # Decoder
        scores, caps_sorted, decode_lengths, _sort_ind, _alphas = self.decoder(
            encoder_out,
            bert_output_id,
            captions_id,
            caption_lengths,
            posi_nega
        )

        # padded copy used for argmax decoding when computing BLEU
        scores_copy = scores.detach().clone()

        # targets skip the first token of each caption; the decoder predicts the next token
        targets = caps_sorted[:, 1:]

        packed_scores = pack_padded_sequence(
            scores,
            decode_lengths,
            batch_first=True
        ).data.to(self.device)

        packed_targets = pack_padded_sequence(
            targets,
            decode_lengths,
            batch_first=True
        ).data.to(self.device)

        encoder_loss = self.encoder_criterion(classification_output, posi_nega)
        decoder_loss = self.decoder_criterion(packed_scores, packed_targets)
        loss = encoder_loss + decoder_loss

        return loss, classification_output, posi_nega, scores_copy, caps_sorted, decode_lengths

    @staticmethod
    def _words_from_batch(caps_sorted, scores_copy, decode_lengths):
        """Turn one batch into BLEU references and hypotheses (lists of words).

        References drop <START> and <PAD> ids; hypotheses are the argmax ids of
        each row truncated to that row's decode length.
        """
        skip_ids = (train_word_dic['<START>'], train_word_dic['<PAD>'])
        references = [
            [[id2word(id_) for id_ in cap if id_ not in skip_ids]]
            for cap in caps_sorted.tolist()
        ]
        _, preds = torch.max(scores_copy, dim=2)
        hypotheses = [
            [id2word(id_) for id_ in pred[:length]]
            for pred, length in zip(preds.tolist(), decode_lengths)
        ]
        return references, hypotheses

    def save_model(self, epoch, file_path="output/bert_trained.model"):
        """Save the CNN encoder, decoder and classification head state_dicts.

        Files are named ``file_path + "ep<epoch>"`` plus '_cnn_encoder.bin',
        '_decoder.bin' and '_classification.bin'. Modules are moved to CPU while
        saving and then moved back to ``self.device``.

        :param epoch: current epoch number
        :param file_path: path prefix for the checkpoint files
        :return: the common prefix ``file_path + "ep<epoch>"``
        """
        output_path = file_path + "ep%d" % epoch
        modules_and_suffixes = (
            (self.cnn_encoder, '_cnn_encoder.bin'),
            (self.decoder, '_decoder.bin'),
            (self.classification, '_classification.bin'),
        )
        for module, suffix in modules_and_suffixes:
            torch.save(module.to('cpu').state_dict(), output_path + suffix)
        for module, _ in modules_and_suffixes:
            module.to(self.device)

        return output_path

    def save_loss_bleu(self, loss, output_path, bleu, image_binary_acc=None, f1=None):
        """Append one epoch's metrics to the history pickle at ``output_path``.

        Pickle files cannot be appended to, so the dict is loaded, extended and
        written back. ``image_binary_acc`` and ``f1`` are optional; when given,
        they are appended (creating the list if an older history lacks it).
        """
        history = pickle_load(output_path)
        history['loss'].append(loss)
        history['bleu'].append(bleu)
        if image_binary_acc is not None:
            history.setdefault('image_binary_acc', []).append(image_binary_acc)
        if f1 is not None:
            history.setdefault('f1', []).append(f1)

        pickle_dump(history, output_path)

In [None]:
torch.manual_seed(SEED)
# Build the trainer from the configuration cell; arguments grouped by concern.
trainer = CNN_BERT_LSTM_Trainer(
    # data / device / output
    train_dataloader=train_data_loader,
    with_cuda=with_cuda,
    output_path=output_path,
    # optimiser
    lr=lr,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    # CNN encoder
    model_depth=model_depth,
    encoder_dim=encoder_dim,
    cnn_pretrain_weight=cnn_pretrain_weight,
    # decoder
    attention_dim=attention_dim,
    embed_dim=bert_hidden_dim,
    decoder_dim=decoder_dim,
    decoder_vocab_size=len(train_word_dic),
    subword_num=total_subword_num,
    dropout=dropout,
    last_n_token=last_n_token,
)
print('START')

# 学習

In [None]:
best_bleu = 0.0  # NOTE(review): never updated here because this loop does not call trainer.dev()
for epoch in range(EPOCHS):
    # one training pass, then a checkpoint tagged with the epoch number
    trainer.train(epoch)
    trainer.save_model(epoch, file_path=output_path)