In [1]:
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, re
import nltk
from collections import Counter
import pickle
import torch.nn as nn
import matplotlib.image as mpimg

In [2]:
# Read the two metadata tables of the Indiana University chest X-ray dataset.
# Both frames are joined on their 'uid' column in the next cell.
projections, reports = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/chest-xrays-indiana-university/indiana_projections.csv',
        '../input/chest-xrays-indiana-university/indiana_reports.csv',
    )
)

In [3]:
# --- Tag vocabulary -------------------------------------------------------
# Each report lists its tags in 'Problems', separated by ';'. Every distinct
# (lower-cased, stripped) tag gets an integer id in first-seen order.
tag_encoder = {}  # tag -> id
index_tag = {}    # id -> tag
for problems in reports['Problems']:
    for tag in problems.split(';'):
        tag = tag.lower().strip()
        if tag not in tag_encoder:
            next_id = len(tag_encoder)
            tag_encoder[tag] = next_id
            index_tag[next_id] = tag
i = len(tag_encoder)  # number of distinct tags (name kept for later cells)

# --- Multi-hot tag strings --------------------------------------------------
# One string of '0'/'1' characters per report, one character per known tag.
onehot_tags = []
for problems in reports['Problems']:
    bits = ["0"] * len(tag_encoder)
    for tag in problems.split(';'):
        bits[tag_encoder[tag.lower().strip()]] = "1"
    onehot_tags.append("".join(bits))

# --- Report text ------------------------------------------------------------
# The target text is "<impression> <findings>"; missing values become the
# literal string 'nan' through str().
context = [
    str(impression) + " " + str(findings)
    for impression, findings in zip(reports['impression'], reports['findings'])
]
reports['context'] = context  # report part
reports['tag'] = onehot_tags

# One row per image: attach the report text and tags to every projection.
data = projections.set_index('uid').join(reports.set_index('uid'))
filename = list(data['filename'])
context = list(data['context'])
tags = list(data['tag'])
output = pd.DataFrame({'filename_1': filename, 'context': context, 'tags': tags})

# Single-view split: the first 6000 rows train, the remainder tests.
# The original used output.loc[6000:3499], a reversed label slice that is
# always empty, so test.tsv was written with no rows.
train = output.loc[:5999]
test = output.loc[6000:]
assert len(train) + len(test) == len(output), "train/test split lost rows"
train.to_csv('./train.tsv', sep="\t", index=False, header=None)
test.to_csv('./test.tsv', sep="\t", index=False, header=None)

In [4]:
# Pair images of the same study: self-join `data` on its index so every image
# is matched with every image sharing that index value, then keep only the
# (Frontal, Lateral) combinations.
data_2 = data.merge(data, how='inner', left_index=True, right_index=True)
data_2 = data_2.loc[
    (data_2['projection_x'] == 'Frontal') & (data_2['projection_y'] == 'Lateral'),
    ['filename_x', 'projection_x', 'filename_y', 'projection_y', 'context_x', 'tag_x'],
]

filename1 = data_2['filename_x'].tolist()
filename2 = data_2['filename_y'].tolist()
context = data_2['context_x'].tolist()
tags = data_2['tag_x'].tolist()
output_2 = pd.DataFrame({
    'filename_1': filename1,
    'filename_2': filename2,
    'context': context,
    'tags': tags,
})

# First 2800 pairs train, the remainder tests.
train_2 = output_2.iloc[:2800]
test_2 = output_2.iloc[2800:]  # test set
train_2.to_csv('./train_2.tsv', sep="\t", index=False, header=None)
test_2.to_csv('./test_2.tsv', sep="\t", index=False, header=None)

In [5]:
class Vocabulary(object):
    """Two-way mapping between words and consecutive integer ids.

    Ids are handed out in insertion order starting at 0. Calling the
    instance with a word returns its id, falling back to the id of the
    '<unk>' entry for words that were never added.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = {}  # id -> word
        self.idx = 0        # next id to assign

    def add_word(self, word):
        """Register `word` under the next free id; known words are ignored."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of `word`, or the id of '<unk>' if it is unknown."""
        try:
            return self.word2idx[word]
        except KeyError:
            return self.word2idx['<unk>']

    def __len__(self):
        """Number of distinct words stored."""
        return len(self.word2idx)


def build_vocab(captions, threshold):
    """Build a Vocabulary from report sentences.

    Args:
        captions: sequence of reports, each a sequence of sentence strings.
        threshold: minimum number of occurrences a token needs to be kept.

    Returns:
        A Vocabulary holding '<pad>', '<start>', '<end>' and '<unk>' first,
        followed by every token whose count is at least `threshold`, in the
        order the tokens were first seen.
    """
    token_counts = Counter()
    total = len(captions)
    for n_done, report in enumerate(captions, start=1):
        for sentence in report:
            token_counts.update(nltk.tokenize.word_tokenize(sentence.lower()))  # tokenize

        # Progress message every 1000 reports.
        if n_done % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(n_done, total))

    vocab = Vocabulary()

    # Special symbols are added before any corpus token.
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Keep only tokens that are frequent enough.
    for word, cnt in token_counts.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab


In [6]:
def create_captions(filepath):
    """Read a TSV file and return its report column as cleaned sentences.

    The report text is taken from the second-to-last tab-separated field of
    every line, split on '.', and each piece is lower-cased and stripped of
    punctuation and digits. Pieces that end up empty are dropped.

    Args:
        filepath: path of the TSV file (read as UTF-8, the same encoding the
            dataset class passes to pandas for this file).

    Returns:
        A list with one entry per non-blank line; each entry is a list of
        cleaned sentences (tokens joined by single spaces).
    """
    # Raw string: the original non-raw literal relied on the invalid escape
    # sequences '\[' and '\]'. The pattern itself is unchanged.
    # NOTE: ':-\[' inside the class is a *range* from ':' to '[', so it also
    # strips ';<=>?@' and leaves '-' untouched. Kept as is so that the
    # tokenization (and therefore any saved vocabulary) does not change.
    strip_chars = re.compile(r'[.,?;*!%^&_+():-\[\]{},0-9]')

    def bioclean(text):
        """Remove quotes/slashes, lower-case, strip symbols, split on spaces."""
        for ch in ('"', '/', '\\', "'"):
            text = text.replace(ch, '')
        return strip_chars.sub('', text.strip().lower()).split()

    captions = []

    with open(filepath, "r", encoding="utf-8") as file:
        for line in file:
            line = line.replace("\n", "")
            if not line:
                # Blank lines carry no record; pandas skips them as well, so
                # skipping here keeps captions aligned with the data frame.
                continue
            fields = line.split("\t")

            sentence_tokens = []
            for sentence in fields[-2].split("."):
                tokens = bioclean(sentence)
                if len(tokens) == 0:
                    continue
                sentence_tokens.append(" ".join(tokens))

            captions.append(sentence_tokens)

    return captions

class iuxray(Dataset):
    """Dataset over a TSV file of chest X-ray images, reports and tags.

    The TSV file has no header and either three columns
    (image, report, tags) or four columns (image 1, image 2, report, tags).

    Args:
        root_dir: directory that contains the TSV file and the image folder.
        tsv_path: TSV file name, relative to `root_dir`.
        image_path: image folder, relative to `root_dir`.
        vocab: vocabulary to encode words with; built from this file's
            reports (minimum count 3) when None.
        transform: optional callable applied to every opened image.
        k: maximum number of sentences kept per report; also returned with
            every sample as the sentence count used for batching.
    """
    def __init__(self, root_dir, tsv_path, image_path, vocab = None, transform=None, k=10):
        self.root_dir = root_dir
        self.tsv_path = tsv_path
        self.image_path = image_path

        tsv_file = os.path.join(self.root_dir, self.tsv_path)

        self.captions = create_captions(tsv_file)
        if vocab is None:
            self.vocab = build_vocab(self.captions, 3)
        else:
            self.vocab = vocab
        # dtype=str keeps the tag column as the '0'/'1' character string that
        # collate_fn indexes; without it pandas is free to infer a numeric
        # dtype for an all-digit column.
        self.data_file = pd.read_csv(tsv_file, delimiter='\t', encoding='utf-8', header=None, dtype=str)
        self.transform = transform
        self.k = k

    def __len__(self):
        """Number of rows in the TSV file."""
        return len(self.data_file)

    def _load_image(self, file_name):
        """Open one image below the image folder and apply the transform."""
        image = Image.open(os.path.join(self.root_dir, self.image_path, file_name))
        if self.transform:
            image = self.transform(image)
        return image

    def _pad_sentences(self, sentences):
        """Right-pad id lists with '<pad>' to a common length.

        Returns (tensor, longest sentence length). A report without any
        sentence yields an empty tensor and length 0 instead of raising.
        """
        max_sent_len = max((len(sentence) for sentence in sentences), default=0)
        pad_id = self.vocab('<pad>')
        padded = [
            sentence + (max_sent_len - len(sentence)) * [pad_id]
            for sentence in sentences
        ]
        return torch.Tensor(padded), max_sent_len

    def _encode_report(self, idx):
        """Encode the first `k` sentences of report `idx` as padded word ids."""
        caption = self.captions[idx]
        sentences = []
        for i in range(min(self.k, len(caption))):
            tokens = nltk.tokenize.word_tokenize(str(caption[i]).lower())
            # Every sentence is wrapped in <start> ... <end>.
            sentence = [self.vocab('<start>')]
            sentence.extend(self.vocab(token) for token in tokens)
            sentence.append(self.vocab('<end>'))
            sentences.append(sentence)
        return self._pad_sentences(sentences)

    def __getitem__(self, idx):
        """Return one sample.

        Three-column file: (image, target, tag, k, max_sent_len).
        Four-column file: (image1, image2, target, tag, k, max_sent_len).

        Raises:
            ValueError: if the TSV file has neither three nor four columns.
        """
        n_cols = self.data_file.shape[1]

        if n_cols == 3:
            image = self._load_image(self.data_file.iloc[idx, 0])
            target, max_sent_len = self._encode_report(idx)
            tag = self.data_file.iloc[idx, 2]  # multi-hot tag string
            return image, target, tag, self.k, max_sent_len  # len_sentences = k

        if n_cols == 4:
            image1 = self._load_image(self.data_file.iloc[idx, 0])
            image2 = self._load_image(self.data_file.iloc[idx, 1])
            target, max_sent_len = self._encode_report(idx)
            tag = self.data_file.iloc[idx, -1]  # multi-hot tag string
            return image1, image2, target, tag, self.k, max_sent_len  # len_sentences = k

        raise ValueError(
            "expected a TSV file with 3 or 4 columns, got {}".format(n_cols))
            


def _pad_reports(captions, len_sentences, max_sent_len):
    """Stack per-sample sentence tensors into one zero-padded batch tensor.

    Returns (targets, prob): targets is a long tensor of shape
    (batch_size, max(len_sentences), max(max_sent_len)); prob is a long
    tensor of shape (batch_size, max(len_sentences)) that is 1 wherever a
    real sentence was copied and 0 elsewhere.
    """
    targets = torch.zeros(len(captions), max(len_sentences), max(max_sent_len)).long()
    prob = torch.zeros(len(captions), max(len_sentences)).long()
    for i, cap in enumerate(captions):
        for j, sent in enumerate(cap):
            targets[i, j, :len(sent)] = sent[:]  # report
            prob[i, j] = 1  # this slot holds a real sentence
    return targets, prob


def _decode_tags(tags):
    """Turn '0'/'1' strings into a long tensor of shape (batch_size, n_tags).

    The number of tags is taken from the first string.
    """
    tag = torch.zeros(len(tags), len(tags[0])).long()
    for i, bits in enumerate(tags):
        for j in range(len(tags[0])):
            tag[i, j] = int(bits[j])
    return tag


def collate_fn(data):
    """Custom batching for samples produced by the dataset.

    Args:
        data: list of samples, each either a 5-tuple
            (image, caption, tag, len_sentences, max_sent_len) or a 6-tuple
            (image1, image2, caption, tag, len_sentences, max_sent_len).

    Returns:
        (images, tag, targets, prob) for 5-tuples, or
        (images1, images2, tag, targets, prob) for 6-tuples.
    """
    fields = list(zip(*data))

    if len(fields) == 5:
        images, captions, tags, len_sentences, max_sent_len = fields
        images = torch.stack(images, 0)
        targets, prob = _pad_reports(captions, len_sentences, max_sent_len)
        return images, _decode_tags(tags), targets, prob

    images1, images2, captions, tags, len_sentences, max_sent_len = fields
    images1 = torch.stack(images1, 0)
    images2 = torch.stack(images2, 0)
    targets, prob = _pad_reports(captions, len_sentences, max_sent_len)
    return images1, images2, _decode_tags(tags), targets, prob
        

In [7]:
def get_loader(root_dir, tsv_path, image_path, transform, batch_size, shuffle, num_workers, vocab = None):
    """Create the dataset for one TSV file and wrap it in a DataLoader.

    Args:
        root_dir, tsv_path, image_path, transform, vocab: forwarded to the
            dataset constructor.
        batch_size, shuffle, num_workers: forwarded to the DataLoader.

    Returns:
        (data_loader, vocab): the loader, batching with `collate_fn`, and
        the vocabulary held by the dataset (the one passed in, or the one the
        dataset built when `vocab` was None).
    """
    dataset = iuxray(root_dir, tsv_path, image_path, vocab=vocab, transform=transform)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    loader = torch.utils.data.DataLoader(dataset=dataset, **loader_options)

    return loader, dataset.vocab

In [8]:
import torchvision.models as models
# CNN encoder
class EncoderCNN(nn.Module):
    """Convolutional image encoder for single-channel images.

    The first layer of the chosen torchvision backbone is replaced by a
    1-channel 3x3 convolution and the classification head is dropped.

    Args:
        model_name: 'resnet50', 'densenet201' or 'vgg19'.
        pretrained: forwarded to the torchvision model constructor.

    Attributes:
        out_features: channel count of the feature maps.
        bn, linear: created here but not used by `forward`; they are kept
            because they are part of the module's state dict.
    """
    def __init__(self, model_name='resnet50', pretrained=False):
        super(EncoderCNN, self).__init__()
        self.model_name = model_name
        self.pretrained = pretrained
        self.model, self.out_features, self.avg_func, self.bn, self.linear = self.__get_model()
        self.activation = nn.ReLU()

    def __get_model(self):
        """Assemble the backbone and its companion layers.

        Returns:
            (model, out_features, avg_pool, bn, linear)

        Raises:
            ValueError: if `model_name` is not one of the supported names.
        """
        model = None
        out_features = None
        func = None
        first_layer = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # our images are 1 channel
        modules = [first_layer]
        if self.model_name == 'resnet50':
            resnet = models.resnet50(pretrained=self.pretrained)
            # Skip the original first conv and the final avgpool + fc.
            modules.extend(list(resnet.children())[1:-2])
            model = nn.Sequential(*modules)
            out_features = resnet.fc.in_features
            func = torch.nn.AvgPool2d(kernel_size=14, stride=1)
        elif self.model_name == 'densenet201':
            densenet = models.densenet201(pretrained=self.pretrained)
            modules.extend(list(densenet.features[1:]))
            model = nn.Sequential(*modules)
            func = torch.nn.AvgPool2d(kernel_size=14, stride=1)
            out_features = densenet.classifier.in_features
        elif self.model_name == 'vgg19':
            vgg = models.vgg19(pretrained=self.pretrained)
            # Skip the original first conv and the last two feature layers.
            modules.extend(list(vgg.features.children())[1:-2])
            model = nn.Sequential(*modules)
            func = torch.nn.AvgPool2d(kernel_size=14, stride=1)
            out_features = list(vgg.features.children())[-3].weight.shape[0]
        else:
            # Previously this fell through to nn.Linear(None, None).
            raise ValueError(
                "unsupported model_name {!r}; expected 'resnet50', "
                "'densenet201' or 'vgg19'".format(self.model_name))
        linear = nn.Linear(in_features=out_features, out_features=out_features)
        bn = nn.BatchNorm1d(num_features=out_features, momentum=0.1)
        return model, out_features, func, bn, linear

    def forward(self, images):
        """Encode a batch of images.

        Returns:
            (visual_features, avg_features): the backbone feature maps and
            their average-pooled version with size-1 spatial axes removed.
        """
        visual_features = self.model(images)
        # Squeeze only the two spatial axes. A bare .squeeze() would also
        # drop the batch axis for a batch of one image, which breaks every
        # consumer that reads the batch size from dimension 0.
        avg_features = self.avg_func(visual_features).squeeze(3).squeeze(2)
        return visual_features, avg_features

# MLC (multi-label classifier) for tags. Input: pooled visual features. Output: a probability for every tag and the embeddings of the top-k tags.
class MLC(nn.Module):
    """Tag classifier that also embeds its k most probable tags.

    Args:
        classes: number of tags.
        sementic_features_dim: width of the tag embeddings.
        fc_in_features: width of the pooled visual features fed in.
        k: number of top-scoring tags to embed.
    """
    def __init__(self, classes=119, sementic_features_dim=512, fc_in_features=2048, k=10):
        super().__init__()
        self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        self.softmax = nn.Softmax(dim=-1)
        self.sem_dim = sementic_features_dim
        self.__init_weight()

    def __init_weight(self):
        """Classifier weights ~ U(-0.1, 0.1); classifier bias = 0."""
        nn.init.uniform_(self.classifier.weight, -0.1, 0.1)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, avg_features):
        """Score all tags and embed the `k` best ones.

        Returns:
            (tags, semantic_features): softmax scores over the tags (last
            axis) and the embeddings of the `k` highest-scoring tag indices.
        """
        logits = self.classifier(avg_features)
        tags = self.softmax(logits)
        _, top_indices = torch.topk(tags, self.k)
        semantic_features = self.embed(top_indices)
        return tags, semantic_features

In [9]:
class AttentionVisual(nn.Module):
    """Additive attention over the spatial positions of a feature map.

    Args:
        vis_enc_dim: channel count of the feature map.
        sent_hidden_dim: width of the decoder hidden state.
        att_dim: width of the shared attention space.
    """
    def __init__(self, vis_enc_dim, sent_hidden_dim, att_dim):
        super().__init__()

        self.enc_att = nn.Linear(vis_enc_dim, att_dim)      # W_v
        self.dec_att = nn.Linear(sent_hidden_dim, att_dim)  # W_v,h
        self.tanh = nn.Tanh()
        self.full_att = nn.Linear(att_dim, 1)               # W_vatt
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, vis_enc_output, dec_hidden_state):
        """Attend over the positions of `vis_enc_output`.

        Args:
            vis_enc_output: feature map laid out (batch, channels, H, W).
            dec_hidden_state: (batch, sent_hidden_dim).

        Returns:
            (att_output, att_scores): the score-weighted sum of the position
            features, (batch, channels), and the scores, (batch, H*W).
        """
        # Channels last, then flatten the spatial grid into one axis.
        channels_last = vis_enc_output.permute(0, 2, 3, 1)
        regions = channels_last.view(channels_last.size(0), -1, channels_last.size(-1))

        region_proj = self.enc_att(regions)                         # (batch, positions, att_dim)
        hidden_proj = self.dec_att(dec_hidden_state).unsqueeze(1)   # (batch, 1, att_dim)

        energies = self.full_att(self.tanh(region_proj + hidden_proj)).squeeze(2)  # (batch, positions)
        att_scores = self.softmax(energies)

        att_output = (att_scores.unsqueeze(2) * regions).sum(dim = 1)  # a_att
        return att_output, att_scores

class AttentionSemantic(nn.Module):
    """Additive attention over a set of tag embeddings.

    Args:
        sem_enc_dim: width of the tag embeddings.
        sent_hidden_dim: width of the decoder hidden state.
        att_dim: width of the shared attention space.
    """
    def __init__(self, sem_enc_dim, sent_hidden_dim, att_dim):
        super().__init__()

        self.enc_att = nn.Linear(sem_enc_dim, att_dim)      # W_a
        self.dec_att = nn.Linear(sent_hidden_dim, att_dim)  # W_a,h
        self.tanh = nn.Tanh()
        self.full_att = nn.Linear(att_dim, 1)               # W_att
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, sem_enc_output, dec_hidden_state):
        """Attend over the tags in `sem_enc_output`.

        Args:
            sem_enc_output: (batch, no_of_tags, sem_enc_dim).
            dec_hidden_state: (batch, sent_hidden_dim).

        Returns:
            (att_output, att_scores): att_scores is (batch, no_of_tags).
            NOTE: att_output is the weighted sum of the *projected* tags, so
            it is (batch, att_dim) wide rather than sem_enc_dim wide.
        """
        tag_proj = self.enc_att(sem_enc_output)                     # (batch, no_of_tags, att_dim)
        hidden_proj = self.dec_att(dec_hidden_state).unsqueeze(1)   # (batch, 1, att_dim)

        energies = self.full_att(self.tanh(tag_proj + hidden_proj)).squeeze(2)  # (batch, no_of_tags)
        att_scores = self.softmax(energies)

        att_output = (att_scores.unsqueeze(2) * tag_proj).sum(dim = 1)  # v_att
        return att_output, att_scores


In [None]:
class selfAttentionVisual(nn.Module):
    """Feature-wise gating of visual features by the decoder hidden state.

    Unlike AttentionVisual, every projection here maps to `vis_enc_dim`, and
    the scores are multiplied element-wise with the input features.

    Args:
        vis_enc_dim: width of the visual features.
        sent_hidden_dim: width of the decoder hidden state.
        att_dim: accepted for signature compatibility; not used.
        version: 'ver_all_bn' (batch norm after every projection) or
            'ver_no_bn'.
        momentum: momentum of the batch-norm layers.
    """
    def __init__(self, vis_enc_dim, sent_hidden_dim, att_dim,version='ver_all_bn',momentum=0.1):
        # Fixed: this used to call super(AttentionVisual, self), which raises
        # TypeError because this class does not derive from AttentionVisual.
        super(selfAttentionVisual, self).__init__()

        self.version = version

        self.enc_att = nn.Linear(vis_enc_dim, vis_enc_dim) # W_v
        self.bn_enc_att = nn.BatchNorm1d(num_features=vis_enc_dim,momentum=momentum)

        self.dec_att = nn.Linear(sent_hidden_dim, vis_enc_dim) # W_v,h
        self.bn_dec_att = nn.BatchNorm1d(num_features=vis_enc_dim,momentum=momentum)

        self.tanh = nn.Tanh()

        self.full_att = nn.Linear(vis_enc_dim, vis_enc_dim) # W_vatt
        self.bn_full_att = nn.BatchNorm1d(num_features=vis_enc_dim,momentum=momentum)

        # Explicit dim: the implicit-dim form is deprecated. For 2-D input
        # (batch, vis_enc_dim) the last axis is the one the implicit form
        # picked, so results are unchanged for that case.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, vis_enc_output, dec_hidden_state):
        """Gate `vis_enc_output` with scores computed from both inputs.

        Returns:
            (att_output, att_scores) where att_output is
            vis_enc_output * att_scores.

        Raises:
            ValueError: if `version` is not a supported name.
        """
        if self.version == 'ver_all_bn':
            att_scores = self.ver_all_bn(vis_enc_output,dec_hidden_state)
        elif self.version == 'ver_no_bn':
            att_scores = self.ver_no_bn(vis_enc_output,dec_hidden_state)
        else:
            # Previously an unknown version hit an unbound local variable.
            raise ValueError(
                "unknown version {!r}; expected 'ver_all_bn' or 'ver_no_bn'".format(self.version))

        att_output = torch.mul(vis_enc_output,att_scores) # a_att
        return att_output, att_scores

    def ver_all_bn(self,vis_enc_output,dec_hidden_state):
        """Scores with batch norm applied after every linear projection."""
        # NOTE(review): the BatchNorm1d layers normalise over vis_enc_dim, so
        # this presumably expects (batch, vis_enc_dim) input - confirm.
        vis_enc_att = self.bn_enc_att(self.enc_att(vis_enc_output))
        dec_output = self.bn_dec_att(self.dec_att(dec_hidden_state.squeeze(1)))

        join_output = self.tanh(vis_enc_att+dec_output)
        join_output = self.bn_full_att(self.full_att(join_output))

        vis_alpha = self.softmax(join_output)

        return vis_alpha

    def ver_no_bn(self,vis_enc_output,dec_hidden_state):
        """Scores from the plain linear projections (no batch norm)."""
        vis_enc_att = self.enc_att(vis_enc_output)
        dec_output = self.dec_att(dec_hidden_state.squeeze(1))

        join_output = self.tanh(vis_enc_att+dec_output)
        join_output = self.full_att(join_output)

        vis_alpha = self.softmax(join_output)

        return vis_alpha
    
    

class selfAttentionSemantic(nn.Module):
    """Attention over tag embeddings with three selectable variants.

    Args:
        sem_enc_dim: width of the tag embeddings (passed to `Attention`).
        sent_hidden_dim: width of the decoder hidden state; every linear
            layer here maps sent_hidden_dim -> sent_hidden_dim.
            NOTE(review): enc_att is applied to the tag embeddings, so this
            presumably requires sem_enc_dim == sent_hidden_dim - confirm.
        att_dim: accepted for signature compatibility; not used.
        version: 'ver_all_bn', 'ver_no_bn' or 'ver_self'.
        k: number of tags; used as the channel count of two batch norms.
        momentum: momentum of the batch-norm layers.
    """
    def __init__(self, sem_enc_dim, sent_hidden_dim, att_dim,version='ver_self',k=10,momentum=0.1):
        # Fixed: this used to call super(AttentionSemantic, self), which
        # raises TypeError because this class does not derive from it.
        super(selfAttentionSemantic, self).__init__()

        self.version = version

        self.enc_att = nn.Linear(sent_hidden_dim, sent_hidden_dim) # W_a
        self.bn_enc_att = nn.BatchNorm1d(num_features=k,momentum=momentum)

        self.dec_att = nn.Linear(sent_hidden_dim, sent_hidden_dim) # W_a,h
        self.bn_dec_att = nn.BatchNorm1d(num_features=1,momentum=momentum)

        self.tanh = nn.Tanh()

        self.full_att = nn.Linear(sent_hidden_dim, sent_hidden_dim) # W_att
        self.bn_full_att =nn.BatchNorm1d(num_features=k,momentum=momentum)

        # Explicit dim: the deprecated implicit form resolves to dim 0 for
        # 3-D input, i.e. it normalised across the batch. The scores weight a
        # sum over axis 1 in forward(), so they are normalised over axis 1.
        # NOTE(review): confirm axis 1 (tags) is the intended axis.
        self.softmax = nn.Softmax(dim=1)

        # `Attention` is not defined in this part of the notebook; it has to
        # be available before this class is instantiated.
        self.SelfAttention = Attention(dim=sem_enc_dim,dropout=0.2)

    def forward(self, sem_enc_output, dec_hidden_state):
        """Compute the attended tag representation.

        Returns:
            (att_output, att_scores)

        Raises:
            ValueError: if `version` is not a supported name.
        """
        if self.version=='ver_all_bn':
            att_scores = self.ver_all_bn(sem_enc_output, dec_hidden_state)
            att_output = torch.mul(att_scores, sem_enc_output).sum(1) # sem_att
        elif self.version=='ver_no_bn':
            att_scores = self.ver_no_bn(sem_enc_output, dec_hidden_state)
            att_output = torch.mul(att_scores, sem_enc_output).sum(1) # sem_att
        elif self.version=='ver_self':
            att_scores, att_output = self.ver_self(sem_enc_output)
        else:
            # Previously an unknown version hit an unbound local variable.
            raise ValueError(
                "unknown version {!r}; expected 'ver_all_bn', 'ver_no_bn' "
                "or 'ver_self'".format(self.version))

        return att_output, att_scores

    def ver_all_bn(self, sem_enc_output, dec_hidden_state):
        """Scores with batch norm applied after every linear projection."""
        sem_enc_output = self.bn_enc_att(self.enc_att(sem_enc_output))
        # unsqueeze(1) adds a singleton axis so the hidden state broadcasts
        # against the tag axis.
        dec_output = self.bn_dec_att(self.dec_att(dec_hidden_state.unsqueeze(1)))

        join_output = self.tanh(torch.add(sem_enc_output,dec_output))
        join_output = self.bn_full_att(self.full_att(join_output))
        sem_alpha = self.softmax(join_output)

        return sem_alpha

    def ver_no_bn(self,sem_enc_output,dec_hidden_state):
        """Scores from the plain linear projections (no batch norm)."""
        sem_enc_output = self.enc_att(sem_enc_output)
        dec_output = self.dec_att(dec_hidden_state.unsqueeze(1))

        join_output = self.tanh(sem_enc_output+dec_output)
        join_output = self.full_att(join_output)
        sem_alpha = self.softmax(join_output)

        return sem_alpha

    def ver_self(self, sem_enc_output):
        """Delegate to the `Attention` module; its result is unpacked by
        forward() as (att_scores, att_output)."""
        return self.SelfAttention(sem_enc_output)

In [10]:
class SentenceLSTM(nn.Module):
    """Sentence-level decoder: one topic vector and one stop distribution
    per sentence slot.

    Args:
        vis_embed_dim: channel count of the visual feature map.
        sem_embed_dim: width of the tag embeddings.
        sent_hidden_dim: hidden width of the sentence LSTM cell.
        att_dim: width of the attention space of both attention modules.
        sent_input_dim: width of the context vector fed to the LSTM cell.
        word_input_dim: width of the produced topic vectors.
        int_stop_dim: hidden width of the stop-control layers.
        version: 0 = visual + semantic attention, 1 = visual attention only.
    """
    # def __init__(self, vis_embed_dim, sent_hidden_dim, att_dim, sent_input_dim, word_input_dim, int_stop_dim):
    def __init__(self, vis_embed_dim, sem_embed_dim, sent_hidden_dim, att_dim, sent_input_dim, word_input_dim, int_stop_dim, version=0):
        super(SentenceLSTM, self).__init__()
        self.version = version
        self.vis_att = AttentionVisual(vis_embed_dim, sent_hidden_dim, att_dim)
        self.sem_att = AttentionSemantic(sem_embed_dim, sent_hidden_dim, att_dim)
#         self.vis_att = selfAttentionVisual(vis_embed_dim, sent_hidden_dim, att_dim) # for self attention
#         self.sem_att = selfAttentionSemantic(sem_embed_dim, sent_hidden_dim, att_dim)

        self.contextLayer = nn.Linear(vis_embed_dim + sem_embed_dim, sent_input_dim) # W_fc
        self.contextLayer1 = nn.Linear(vis_embed_dim, sent_input_dim) # only visual feature
        self.lstm = nn.LSTMCell(sent_input_dim, sent_hidden_dim, bias=True) # LSTMCell

        self.sent_hidden_dim = sent_hidden_dim
        self.word_input_dim = word_input_dim

        self.topic_hid_layer = nn.Linear(sent_hidden_dim, word_input_dim)
        self.topic_context_layer = nn.Linear(sent_input_dim, word_input_dim)
        self.tanh1 = nn.Tanh()

        self.stop_prev_hid = nn.Linear(sent_hidden_dim, int_stop_dim)
        self.stop_cur_hid = nn.Linear(sent_hidden_dim, int_stop_dim)
        self.tanh2 = nn.Tanh()
        self.final_stop_layer = nn.Linear(int_stop_dim, 2)

    def forward(self, vis_enc_output, tags, device):
        """Dispatch to `v0` or `v1` according to `self.version`.

        Args:
            vis_enc_output: visual feature map passed to the visual attention.
            tags: tag embeddings; `tags.shape[1]` sets the number of
                sentence slots that are decoded.
            device: device the state and output tensors are created on.

        Returns:
            (topics, ps): topics is (batch, slots, word_input_dim); ps is
            (batch, slots, 2) and holds softmax stop probabilities.

        Raises:
            ValueError: if `self.version` is neither 0 nor 1 (this used to
                return None silently).
        """
        if self.version == 0:
            return self.v0(vis_enc_output, tags, device)
        if self.version == 1:
            return self.v1(vis_enc_output, tags, device)
        raise ValueError("unknown version {!r}; expected 0 or 1".format(self.version))

    def _decode(self, make_context, batch_size, num_steps, device):
        """Run the sentence LSTM for `num_steps` slots.

        `make_context(h)` must return the (batch, sent_input_dim) context
        vector for the current hidden state `h`.
        """
        h = torch.zeros((batch_size, self.sent_hidden_dim)).to(device)
        c = torch.zeros((batch_size, self.sent_hidden_dim)).to(device)

        topics = torch.zeros((batch_size, num_steps, self.word_input_dim)).to(device) # (batch_size, k, word_input_dim)
        ps = torch.zeros((batch_size, num_steps, 2)).to(device)

        for t in range(num_steps):
            context_output = make_context(h) # (batch_size, sent_input_dim)

            h_prev = h
            h, c = self.lstm(context_output, (h, c)) # (batch_size, sent_hidden_dim) each

            topic = self.tanh1(self.topic_hid_layer(h) + self.topic_context_layer(context_output)) # (batch_size, word_input_dim)

            # The stop signal looks at the hidden state before and after
            # this step.
            stop_hidden = self.tanh2(self.stop_prev_hid(h_prev) + self.stop_cur_hid(h)) # (batch_size, int_stop_dim)
            p = torch.softmax(self.final_stop_layer(stop_hidden), 1) # (batch_size, 2)

            topics[:, t, :] = topic
            ps[:, t, :] = p
        return topics, ps

    def v0(self, vis_enc_output, tags, device):
        """Decode with visual and semantic attention combined."""
        batch_size = vis_enc_output.shape[0]
        sem_enc_output = tags.view(batch_size, -1, tags.shape[-1])

        def make_context(h):
            vis_att_output, _ = self.vis_att(vis_enc_output, h)
            sem_att_output, _ = self.sem_att(sem_enc_output, h)
            return self.contextLayer(torch.cat([vis_att_output, sem_att_output], dim = 1))

        return self._decode(make_context, batch_size, tags.shape[1], device)

    def v1(self, vis_enc_output, tags, device):
        """Decode with visual attention only.

        `tags` is used only for the number of sentence slots. The stop
        output is now softmax-normalised like in `v0`; it used to be raw
        logits, so the two versions disagreed on what `ps` contains.
        """
        batch_size = vis_enc_output.shape[0]

        def make_context(h):
            vis_att_output, _ = self.vis_att(vis_enc_output, h)
            return self.contextLayer1(vis_att_output)

        return self._decode(make_context, batch_size, tags.shape[1], device)


class WordLSTM(nn.Module):
    """Word-level decoder that expands one topic vector into one sentence.

    Args:
        word_input_dim: width of the word embeddings.
        word_hidden_dim: hidden width of the LSTM cell. The topic vector is
            used directly as the initial hidden state.
        vocab_size: number of words in the vocabulary.
        num_layers: accepted for signature compatibility; not used.
    """
    def __init__(self, word_input_dim, word_hidden_dim, vocab_size, num_layers = 1):
        super().__init__()
        self.word_hidden_dim = word_hidden_dim
        self.word_input_dim = word_input_dim
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, word_input_dim)
        self.lstm = nn.LSTMCell(word_input_dim, word_hidden_dim, bias=True)
        self.fc = nn.Linear(word_hidden_dim, vocab_size)

    def forward(self, topic, caption, device): # teacher forcing training
        """Score the next word at every position of a known sentence.

        Args:
            topic: (batch_size, word_hidden_dim) initial hidden state.
            caption: (batch_size, max_sent_len) word ids fed in one by one.
            device: device the buffers are created on.

        Returns:
            (batch_size, max_sent_len - 1, vocab_size) scores; the output
            of the last input position is dropped.
        """
        n_sentences, n_words = caption.shape[0], caption.shape[1]
        scores = torch.zeros((n_sentences, n_words, self.vocab_size)).to(device)
        word_vectors = self.embedding(caption).to(device) # (batch_size, max_sent_len, word_input_dim)

        hidden = topic.to(device)
        cell = torch.zeros((n_sentences, self.word_hidden_dim)).to(device)

        # One LSTM step per word of the sentence.
        for step in range(n_words):
            hidden, cell = self.lstm(word_vectors[:, step, :], (hidden, cell))
            scores[:, step, :] = self.fc(hidden) # (batch_size, vocab_size)

        return scores[:, :-1, :]

    def val(self, topic, max_sent_len, device):
        """Greedily generate `max_sent_len` words starting from word id 1.

        Returns:
            (batch_size, max_sent_len, vocab_size) scores; at every step the
            highest-scoring word is fed back as the next input.
        """
        n_sentences = topic.shape[0]
        scores = torch.zeros((n_sentences, max_sent_len, self.vocab_size)).to(device)

        # Every sentence starts from word id 1.
        first_ids = torch.tensor([[1] for _ in range(n_sentences)]).to(device)
        current = self.embedding(first_ids).to(device) # (batch_size, 1, word_input_dim)

        hidden = topic.to(device)
        cell = torch.zeros((n_sentences, self.word_hidden_dim)).to(device)

        for step in range(max_sent_len):
            hidden, cell = self.lstm(current[:, 0, :], (hidden, cell))
            step_scores = self.fc(hidden) # (batch_size, vocab_size)
            scores[:, step, :] = step_scores
            best_ids = torch.topk(step_scores, 1)[1]
            current = self.embedding(best_ids).to(device)

        return scores
            

In [44]:
def script(root_dir, train_tsv_path, val_tsv_path, image_path, img_size=224, crop_size=224, batch_size=8, shuffle=True, num_workers=0,
            enc_dim=512, sent_hidden_dim=512, att_dim=512, sent_input_dim=512, word_input_dim=512, int_stop_dim=512,
            word_hidden_dim=512, num_layers=1,
            learning_rate_cnn=1e-5, learning_rate_word=5e-4, learning_rate_sent=1e-5, weight_decay_cnn=1e-4, weight_decay_sent=1e-4, weight_decay_word=1e-4,
            num_epochs=200, lambda_sent=1.0, lambda_word=1.0, lambda_mlc=1.0, log_step=100, save_step=50, extend=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([ 
        transforms.Resize(img_size),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor()
    ])

    train_loader, vocab = get_loader(
        root_dir, 
        train_tsv_path, 
        image_path, 
        transform, 
        batch_size, 
        shuffle, 
        num_workers)

    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)

    val_loader, _ = get_loader(
        root_dir,
        val_tsv_path,
        image_path,
        transform,
        batch_size,
        shuffle,
        num_workers,
        vocab)
    
    encoderCNN = EncoderCNN('vgg19', pretrained=True).to(device)
    if extend:
        mlc = MLC(fc_in_features=encoderCNN.out_features*2).to(device)
    else:
        mlc = MLC(fc_in_features=encoderCNN.out_features).to(device)
    sentLSTM = SentenceLSTM(encoderCNN.out_features, mlc.sem_dim, sent_hidden_dim=512, att_dim=512, sent_input_dim=512, word_input_dim=512, int_stop_dim=512).to(device)
    wordLSTM = WordLSTM(word_input_dim, word_hidden_dim, vocab_size, num_layers).to(device)
    
    encoderCNN.load_state_dict(torch.load('../input/models/encoderCNN.ckpt'))
    mlc.load_state_dict(torch.load('../input/models/mlc.ckpt'))
    sentLSTM.load_state_dict(torch.load('../input/models/sentLSTM.ckpt'))
    wordLSTM.load_state_dict(torch.load('../input/models/wordLSTM.ckpt'))
    
#     criterion_stop = nn.CrossEntropyLoss().to(device)
#     criterion_words = nn.CrossEntropyLoss().to(device)

#     params_cnn = list(encoderCNN.parameters()) + list(mlc.parameters())
#     params_sent = list(sentLSTM.parameters()) 
#     params_word = list(wordLSTM.parameters())
#     optim_cnn = torch.optim.Adam(params = params_cnn, lr=learning_rate_cnn, weight_decay=weight_decay_cnn)
#     optim_sent = torch.optim.Adam(params = params_sent, lr=learning_rate_sent, weight_decay=weight_decay_sent)
#     optim_word = torch.optim.Adam(params = params_word, lr=learning_rate_word, weight_decay=weight_decay_word)


#     total_step = len(train_loader)
    
    
#     for epoch in range(num_epochs):
#         encoderCNN.train()
#         mlc.train()
#         sentLSTM.train()
#         wordLSTM.train()
#         if not extend:
#             for i, (images, tags, captions, prob) in enumerate(train_loader):
#                 optim_cnn.zero_grad()
#                 optim_word.zero_grad()
#                 optim_sent.zero_grad()
#                 batch_size = images.shape[0]
#                 images = images.to(device)
#                 captions = captions.to(device)
#                 prob = prob.to(device)
#                 tags = tags.to(device)

#                 vis_enc_output, avg_enc_output = encoderCNN(images)
#                 pred_tags, semantic_features = mlc(avg_enc_output)

#     #             log_prob = torch.nn.functional.log_softmax(pred_tags, dim=1).to(device)
#     #             loss_func_mse = nn.MSELoss().to(device)
#                 loss_func_ce = nn.BCELoss().to(device)
#                 loss_mlc = loss_func_ce(pred_tags, tags.to(torch.float))
#                 # semantic_features = torch.randn((batch_size, 10, 512))
#                 topics, ps = sentLSTM(vis_enc_output, semantic_features, device)

#                 loss_word = torch.tensor([0.0]).to(device)
#                 for j in range(captions.shape[1]): # teacher forcing了
#                     word_outputs = wordLSTM(topics[:, j, :], captions[:, j, :], device) # (batch_size, max_sent_len, vocab_size)
#                     loss_word += criterion_words(word_outputs.contiguous().view(-1, vocab_size), captions[:, j, 1:].contiguous().view(-1))
#                 ps = torch.log(ps)
#                 loss_sent = nn.NLLLoss()(ps.view(-1, 2), prob.view(-1))
#                 loss = lambda_sent * loss_sent + lambda_word * loss_word  + lambda_mlc * loss_mlc
#     #             loss = loss_mlc
#                 loss.backward()
#                 optim_cnn.step()
#                 optim_word.step()
#                 optim_sent.step()


#                 if i % log_step == 0:
#                     print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
#                           .format(epoch, num_epochs, i, total_step, loss.item())) 
#                     print('sent_loss:{}, word_loss:{}, tag_loss:{}'
#                           .format(loss_sent, loss_word.item(), loss_mlc))


#             # 保存模型参数
#             if (epoch+1) % save_step == 0: 
#                 torch.save(encoderCNN.state_dict(), 'encoderCNN'+str(epoch+63)+'.ckpt')
#                 torch.save(mlc.state_dict(), 'mlc'+str(epoch+63)+'.ckpt')
#                 torch.save(sentLSTM.state_dict(), 'sentLSTM'+str(epoch+63)+'.ckpt')
#                 torch.save(wordLSTM.state_dict(), 'wordLSTM'+str(epoch+63)+'.ckpt')
#         else:
#             for i, (images1, images2, tags, captions, prob) in enumerate(train_loader):
#                 optim_cnn.zero_grad()
#                 optim_word.zero_grad()
#                 optim_sent.zero_grad()
                
#                 batch_size = images1.shape[0]
#                 images1 = images1.to(device)
#                 images2 = images2.to(device)
#                 captions = captions.to(device)
#                 prob = prob.to(device)
#                 tags = tags.to(device)

#                 vis_enc_output_1, avg_enc_output_1 = encoderCNN(images1)
#                 vis_enc_output_2, avg_enc_output_2 = encoderCNN(images2)
#                 vis_enc_output = torch.cat((vis_enc_output_1, vis_enc_output_2), 2)
#                 avg_enc_output = torch.cat((avg_enc_output_1, avg_enc_output_2), 1)
#                 pred_tags, semantic_features = mlc(avg_enc_output)

#     #             log_prob = torch.nn.functional.log_softmax(pred_tags, dim=1).to(device)
#     #             loss_func_mse = nn.MSELoss().to(device)
#                 loss_func_ce = nn.BCELoss().to(device)
#                 loss_mlc = loss_func_ce(pred_tags, tags.to(torch.float))
#                 # semantic_features = torch.randn((batch_size, 10, 512))
#                 topics, ps = sentLSTM(vis_enc_output, semantic_features, device)

#                 loss_word = torch.tensor([0.0]).to(device)
#                 for j in range(captions.shape[1]): # teacher forcing了
#                     word_outputs = wordLSTM(topics[:, j, :], captions[:, j, :], device) # (batch_size, max_sent_len, vocab_size)
#                     loss_word += criterion_words(word_outputs.contiguous().view(-1, vocab_size), captions[:, j, 1:].contiguous().view(-1))
#                 ps = torch.log(ps)
#                 loss_sent = nn.NLLLoss()(ps.view(-1, 2), prob.view(-1))
#                 loss = lambda_sent * loss_sent + lambda_word * loss_word  + lambda_mlc * loss_mlc
#     #             loss = loss_mlc
#                 loss.backward()
#                 optim_cnn.step()
#                 optim_word.step()
#                 optim_sent.step()


#                 if i % log_step == 0:
#                     print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
#                           .format(epoch, num_epochs, i, total_step, loss.item())) 
#                     print('sent_loss:{}, word_loss:{}, tag_loss:{}'
#                           .format(loss_sent, loss_word.item(), loss_mlc))


#             # 保存模型参数
#             if (epoch+1) % save_step == 0: 
#                 torch.save(encoderCNN.state_dict(), 'encoderCNN'+str(epoch)+'_2.ckpt')
#                 torch.save(mlc.state_dict(), 'mlc'+str(epoch)+'_2.ckpt')
#                 torch.save(sentLSTM.state_dict(), 'sentLSTM'+str(epoch)+'_2.ckpt')
#                 torch.save(wordLSTM.state_dict(), 'wordLSTM'+str(epoch)+'_2.ckpt')
            
    return evaluate(val_loader, encoderCNN, mlc, sentLSTM, wordLSTM, vocab, device, extend)

In [12]:
def _decode_sentence(word_ids, vocab):
    """Turn a sequence of vocabulary indices into a sentence string.

    The ``<pad>``, ``<start>`` and ``<end>`` tokens are dropped, the remaining
    words are joined with single spaces and a trailing period is appended.
    A sequence made only of special tokens therefore decodes to ``"."``.
    """
    special = {vocab.word2idx['<pad>'], vocab.word2idx['<start>'], vocab.word2idx['<end>']}
    return " ".join(vocab.idx2word[w] for w in word_ids if w not in special) + "."


def evaluate(val_loader, encoderCNN, mlc, sentLSTM, wordLSTM, vocab, device, extend=False,
             tag_threshold=0.05, sent_threshold=0.5):
    """Generate reports for every sample in ``val_loader`` and print them next to the targets.

    Args:
        val_loader: iterable of batches. Each batch is
            ``(images, tags, captions, prob)`` when ``extend`` is false and
            ``(images1, images2, tags, captions, prob)`` when it is true.
            ``captions`` is indexed as ``[sample, sentence, word]`` and ``prob``
            as ``[sample, sentence]`` (1 marks a real target sentence).
        encoderCNN: callable returning ``(vis_enc_output, avg_enc_output)`` for an image batch.
        mlc: callable returning ``(pred_tags, semantic_features)``.
        sentLSTM: callable returning ``(topics, ps)``; ``ps[..., 1]`` is compared
            against ``sent_threshold`` to decide whether a sentence is emitted.
        wordLSTM: object whose ``val(topic, max_len, device)`` returns per-word scores
            with the vocabulary on the last axis.
        vocab: object exposing ``idx2word`` and ``word2idx``.
        device: device the batch tensors are moved to.
        extend: when true, two images per sample are encoded and their features
            concatenated (axis 2 for the visual features, axis 1 for the pooled ones).
        tag_threshold: minimum predicted score for a tag to be reported.
        sent_threshold: ``ps[..., 1]`` must be strictly greater than this for a
            predicted sentence to be kept.

    Returns:
        ``(hypotheses, references)``: two lists with one entry per sample, each entry
        being a list of sentence strings.

    Note:
        Tag names are looked up in the module-level ``index_tag`` mapping.
    """
    for model in (encoderCNN, sentLSTM, wordLSTM, mlc):
        model.eval()

    references = []
    hypotheses = []

    # Pure inference: do not record autograd graphs for the forward passes.
    with torch.no_grad():
        for batch in val_loader:
            if extend:
                images1, images2, tags, captions, prob = batch
                vis_enc_output_1, avg_enc_output_1 = encoderCNN(images1.to(device))
                vis_enc_output_2, avg_enc_output_2 = encoderCNN(images2.to(device))
                vis_enc_output = torch.cat((vis_enc_output_1, vis_enc_output_2), 2)
                avg_enc_output = torch.cat((avg_enc_output_1, avg_enc_output_2), 1)
            else:
                images, tags, captions, prob = batch
                vis_enc_output, avg_enc_output = encoderCNN(images.to(device))

            captions = captions.to(device)
            prob = prob.to(device)
            tags = tags.to(device)

            pred_tags, semantic_features = mlc(avg_enc_output)
            topics, ps = sentLSTM(vis_enc_output, semantic_features, device)

            batch_size, sent_num, sent_len = captions.shape[:3]
            # The first caption position is skipped, so one word fewer is generated.
            max_words = sent_len - 1

            # Integer dtype so the decoded values are valid vocabulary indices
            # (a float tensor would hand floats to ``vocab.idx2word``).
            pred_words = torch.zeros((batch_size, sent_num, max_words), dtype=torch.long)
            for j in range(sent_num):
                word_outputs = wordLSTM.val(topics[:, j, :], max_words, device)
                _, words = torch.max(word_outputs, 2)
                pred_words[:, j, :] = words

            for j in range(batch_size):
                target_tag = []
                pred_tag = []
                for t in range(tags.shape[1]):
                    if tags[j, t] == 1:
                        target_tag.append(index_tag[t])
                    if pred_tags[j, t] >= tag_threshold:
                        pred_tag.append(index_tag[t])
                print("real tags: {}".format("; ".join(target_tag)))
                print("pred tags: {}".format("; ".join(pred_tag)))

                pred_caption = []
                target_caption = []
                for k in range(sent_num):
                    if float(ps[j, k, 1]) > sent_threshold:
                        cap = _decode_sentence(pred_words[j, k, :].tolist(), vocab)
                        # Skip predicted sentences that contain only special tokens.
                        if cap != ".":
                            pred_caption.append(cap)

                    if prob[j, k] == 1:
                        target_caption.append(_decode_sentence(captions[j, k, :].tolist(), vocab))
                print("real captions: {}".format(" ".join(target_caption)))
                print("pred captions: {}".format(" ".join(pred_caption)))
                hypotheses.append(pred_caption)
                references.append(target_caption)
                print()

    assert len(references) == len(hypotheses)
    return hypotheses, references

In [45]:
# Run the pipeline on the train/test TSV files and keep the generated and
# reference reports for scoring.
# NOTE(review): the epoch loop visible in `script` is commented out, so
# `num_epochs` presumably has no effect here — confirm.
hypotheses, references = script(
    '',
    './train.tsv',
    './test.tsv',
    '../input/chest-xrays-indiana-university/images/images_normalized',
    num_epochs=10,
)

In [14]:
import shutil

# Copy the cococaption package into the working directory so that
# `from cococaption...` imports resolve. `shutil.copytree` raises
# FileExistsError when the destination already exists, so skip the copy on
# re-runs to keep this cell idempotent.
if not os.path.isdir('./cococaption'):
    shutil.copytree(r'../input/cococaptions/cococaption', r'./cococaption')

In [15]:
from cococaption.pycocotools.coco import COCO
from cococaption.pycocoevalcap.eval import COCOEvalCap
import json

def evalscores(hypotheses, references):
    """Score generated reports against reference reports with the COCO caption metrics.

    Each sample's sentences are joined with spaces into one caption. The
    references are written to ``coco.json`` and the hypotheses to ``res.json``
    in the current working directory, then evaluated with ``COCOEvalCap``;
    every metric is printed as ``name: score``.

    Args:
        hypotheses: list with one entry per sample, each a list of generated sentences.
        references: list with one entry per sample, each a list of target sentences.

    Raises:
        ValueError: if the two lists do not have the same length.
    """
    if len(hypotheses) != len(references):
        raise ValueError(
            "hypotheses and references must have the same length, got {} and {}".format(
                len(hypotheses), len(references)))

    coco_ann_file = 'coco.json'
    res_ann_file = 'res.json'

    # The sample position doubles as image id and annotation id.
    img_annotations = [{"id": i, "file_name": i} for i in range(len(hypotheses))]
    targ_annotations = [
        {"image_id": i, "id": i, "caption": " ".join(sentences)}
        for i, sentences in enumerate(references)
    ]
    res_annotations = [
        {"image_id": i, "id": i, "caption": " ".join(sentences)}
        for i, sentences in enumerate(hypotheses)
    ]

    coco_dict = {"type": 'captions',
                 "images": img_annotations,
                 "annotations": targ_annotations}

    with open(coco_ann_file, 'w') as fp:
        json.dump(coco_dict, fp)

    # The results file holds the bare annotation list, not a wrapping dict.
    with open(res_ann_file, 'w') as fs:
        json.dump(res_annotations, fs)

    coco = COCO(coco_ann_file)
    cocoRes = coco.loadRes(res_ann_file)

    cocoEval = COCOEvalCap(coco, cocoRes)
    cocoEval.evaluate()

    for metric, score in cocoEval.eval.items():
        print('%s: %.3f' % (metric, score))

In [46]:
# Print the caption metrics for the reports produced above.
evalscores(hypotheses=hypotheses, references=references)

In [42]:
def generate(image1, image2, vocab):
    """Print the predicted tags and generated report for a single image.

    Args:
        image1: path of the image to describe.
        image2: currently ignored. The checkpoints loaded here are used with the
            features of one image only, so a second view is not loaded or encoded.
        vocab: vocabulary object exposing ``idx2word``, ``word2idx`` and ``len()``.

    Side effects:
        Builds the four models, loads their weights from ``../input/models/*112.ckpt``
        and prints ``pred tags: ...`` and ``pred captions: ...``. Nothing is returned.

    Note:
        Tag names are looked up in the module-level ``index_tag`` mapping.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(224),
        # Deterministic crop: a random crop would make repeated calls on the
        # same image produce different reports.
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    image1 = Image.open(image1)
    image1 = transform(image1).to(device)
    image1 = image1.unsqueeze(0)  # add the batch dimension

    vocab_size = len(vocab)
    encoderCNN = EncoderCNN('vgg19').to(device)
    mlc = MLC(fc_in_features=encoderCNN.out_features).to(device)
    sentLSTM = SentenceLSTM(encoderCNN.out_features, mlc.sem_dim, sent_hidden_dim=512, att_dim=512, sent_input_dim=512, word_input_dim=512, int_stop_dim=512).to(device)
    wordLSTM = WordLSTM(word_input_dim=512, word_hidden_dim=512, vocab_size=vocab_size, num_layers=1).to(device)

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    encoderCNN.load_state_dict(torch.load('../input/models/encoderCNN112.ckpt', map_location=device))
    mlc.load_state_dict(torch.load('../input/models/mlc112.ckpt', map_location=device))
    sentLSTM.load_state_dict(torch.load('../input/models/sentLSTM112.ckpt', map_location=device))
    wordLSTM.load_state_dict(torch.load('../input/models/wordLSTM112.ckpt', map_location=device))

    # Inference mode, matching evaluate(): freshly built modules start in training mode.
    for model in (encoderCNN, mlc, sentLSTM, wordLSTM):
        model.eval()

    with torch.no_grad():
        vis_enc_output, avg_enc_output = encoderCNN(image1)
        pred_tags, semantic_features = mlc(avg_enc_output)
        topics, ps = sentLSTM(vis_enc_output, semantic_features, device)

        pred_words = []
        for j in range(topics.shape[1]):
            # Stop at the first sentence slot that is not predicted to exist.
            if ps[0, j, 1] > 0.5:
                word_outputs = wordLSTM.val(topics[:, j, :], 30, device)  # at most 30 words per sentence
                _, words = torch.max(word_outputs, 2)
                pred_words.append(words)
            else:
                break

    # Flatten the single-sample tag scores so each tag is compared on its own;
    # iterating the batch axis would compare a whole row of scores at once.
    tag_scores = pred_tags.reshape(-1)
    pred_tag = []
    for t in range(tag_scores.shape[0]):
        if tag_scores[t] >= 0.05:
            pred_tag.append(index_tag[t])
    print("pred tags: {}".format("; ".join(pred_tag)))

    special = {vocab.word2idx['<pad>'], vocab.word2idx['<start>'], vocab.word2idx['<end>']}
    pred_caption = []
    for sentence in pred_words:
        words_x = sentence[0, :].tolist()
        cap = " ".join([vocab.idx2word[w] for w in words_x if w not in special]) + "."
        # Skip sentences that contain only special tokens.
        if cap != ".":
            pred_caption.append(cap)
    print("pred captions: {}".format(" ".join(pred_caption)))
    
    


In [33]:
# Build the vocabulary from the training captions.
# NOTE(review): the second argument (3) is presumably a minimum word count —
# confirm against build_vocab.
vocab = build_vocab(
    create_captions('./train.tsv'),
    3,
)

In [43]:
# Generate a report for one sample image (no second view supplied).
generate(
    image1="../input/chest-xrays-indiana-university/images/images_normalized/1600_IM-0390-1001.dcm.png",
    image2=None,
    vocab=vocab,
)