In [1]:
import os
import torch
import torch.nn.functional as F

import numpy as np
import argparse
import torch.nn as nn
import random
import cv2
import matplotlib.pyplot as plt
from skimage.transform import resize

from tqdm import trange
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable

from dataset_utility import raven_tsne, ToTensor
from model.RANet_attmap import RANet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = argparse.Namespace(
    model_name = 'RANet',
    batch_size = 32,
    root = '../dataset/',
    dataset = 'IRAVEN',
    fig_type = 'distribute_nine',
    img_size = 160,
    workers = 4,
    save = './results/checkpoint/',
    train_mode = 0,
    train_once = False,
    seed = 12345,
)

In [3]:
if args.fig_type == 'all':
    args.train_once = True
else:
    args.train_once = False

if args.dataset == 'PGM':
    args.train_once = False

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'device: {device}')
# torch.manual_seed(args.seed)
# if torch.cuda.is_available:
#     torch.cuda.manual_seed(args.seed)
# np.random.seed(args.seed)
# random.seed(args.seed)

device: cuda


In [5]:
model = RANet()
model.to(device)

RANett(
  (scattering): Scattering()
  (conv): CNNModule(
    (features): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride

In [6]:
save_name = args.model_name + '_' + args.dataset + '_' + args.fig_type

save_path_model = os.path.join(args.save, save_name)

tf = transforms.Compose([ToTensor()])   

In [7]:
train_set = raven_tsne(os.path.join(args.root, args.dataset), 'train', args.fig_type, args.img_size, tf)

test_set = raven_tsne(os.path.join(args.root, args.dataset), 'test', args.fig_type, args.img_size, tf)

print('train length', len(train_set), args.fig_type)
print('test length', len(test_set), args.fig_type)

train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

train length 6000 distribute_nine
test length 2000 distribute_nine


In [8]:
def transparent_cmap(cmap, N=255):
    "Copy colormap and set alpha values"

    mycmap = cmap
    mycmap._init()
    mycmap._lut[:,-1] = np.linspace(0, 0.8, N+4)
    return mycmap

In [9]:
def vis(rows):
    batch_size, n_rows, n_panels, _, panel_h, panel_w = rows.shape

    for i in range(batch_size):
        fig, axs = plt.subplots(n_rows, n_panels, figsize=(6,14))

        for j in range(n_rows):
            for k in range(n_panels):
                img = rows[i,j,k].squeeze().cpu().numpy()

                axs[j, k].imshow(img, cmap='gray', aspect='auto')
                axs[j, k].axis('off')

        plt.show()

In [10]:
def visualize_attention(rows, att_maps):
    batch_size, n_rows, n_panels, _, panel_h, panel_w = rows.shape
    _, _, _, _, att_h, att_w = att_maps.shape

    cmap = plt.cm.Reds
    t_cmap = transparent_cmap(cmap)
    #print(att_maps[0][0][0][0])
    for i in range(batch_size):
        fig, axs = plt.subplots(n_rows, n_panels, figsize=(6, 14))  # Adjusted subplot arrangement
        #print(att_maps.shape)
        for j in range(n_rows):
            for k in range(n_panels):
                img = rows[i, j, k].squeeze().cpu().numpy()
                att_map = att_maps[i, j, k].squeeze().cpu().numpy()
                att_map_resized = cv2.resize(att_map, (panel_h, panel_w), interpolation=cv2.INTER_NEAREST)
                att_map_resized = att_map_resized / float(att_map_resized.sum())
                
                axs[j, k].imshow(img, cmap='gray', aspect='auto')
                axs[j, k].imshow(att_map_resized, cmap=t_cmap, alpha=0.5, aspect='auto')  # overlay attention map
                axs[j, k].axis('off')

                if j>=2:
                    axs[j, k].text(panel_w, panel_h, str(j-2), verticalalignment='bottom', horizontalalignment='right', color='black', fontsize=15)
            
        plt.show()


In [11]:
def make_rows(questions, answers):
    row1 = questions[:,:3].unsqueeze(1)
    row2 = questions[:,3:6].unsqueeze(1)
    row3_p = questions[:,6:8].unsqueeze(1).repeat(1,8,1,1,1,1)

    candidates = answers.unsqueeze(2)
    row3 = torch.cat([row3_p, candidates], dim=2)
    rows = torch.cat([row1,row2,row3], dim=1)
    
    return rows

In [12]:
def meta_to_rule(meta_matrix):
    relations = ["Constant", "Progression", "Arithmetic", "Distribute_Three"]
    attributes = ["Number", "Position", "Type", "Size", "Color"]
    #print(meta_matrix)
    for row in meta_matrix:
        rule = []
        for i, value in enumerate(row):
            if value == 1:
                if i < 4:  # relation
                    rule.append(relations[i])
                else:  # attribute
                    rule.append(attributes[i-4])
        if rule:  # if rule is not empty
            print(" ".join(rule))


In [15]:
def save_image(fig, filename):
    fig.savefig(filename)
    plt.close(fig)

def plot_rows(rows, model):
    rows = rows.squeeze().cpu().numpy()
    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    for j in range(3):
        axs[j].imshow(rows[0, j], cmap='gray')
        axs[j].axis('off')
    plt.tight_layout()
    save_image(fig, 'rows_'+str(model)+'.png')

def plot_feat_map(feat_map, model):
    feat_map = feat_map.squeeze().cpu().numpy()
    fig, axs = plt.subplots(32, 3, figsize=(9, 100))
    for j in range(3):
        for k in range(32):
            resized_feat_map = cv2.resize(feat_map[0, j, k], (160, 160), interpolation=cv2.INTER_NEAREST)
            axs[k, j].imshow(resized_feat_map, cmap='gray')
            axs[k, j].axis('off')
    plt.tight_layout()
    save_image(fig, 'feat_map_'+str(model)+'.png')

def plot_att_map(att_map, model):
    att_map = att_map.squeeze().cpu().numpy()
    fig, axs = plt.subplots(32, 3, figsize=(9, 100))
    for j in range(3):
        for k in range(32):
            resized_att_map = cv2.resize(att_map[0, j, k], (160, 160), interpolation=cv2.INTER_NEAREST)
            axs[k, j].imshow(resized_att_map, cmap='gray')
            axs[k, j].axis('off')
    plt.tight_layout()
    save_image(fig, 'att_map_'+str(model)+'.png')

def plot_aug_feat(aug_feat, model):
    aug_feat = aug_feat.squeeze().cpu().numpy()
    fig, axs = plt.subplots(32, 3, figsize=(9, 100))
    for j in range(3):
        for k in range(32):
            resized_aug_feat = cv2.resize(aug_feat[0, j, k], (160, 160), interpolation=cv2.INTER_NEAREST)
            axs[k, j].imshow(resized_aug_feat, cmap='gray')
            axs[k, j].axis('off')
    plt.tight_layout()
    save_image(fig, 'aug_feat_'+str(model)+'.png')

In [16]:
if __name__ == '__main__':
    model.load_state_dict(torch.load('./model.pth'))
    model.eval()
    data_i = []
    for i in range(10):
        number = random.randint(0,1999)
        data_i.append(i)

    for i in data_i:
        sample_data = test_set[i]
        
        image, target, meta_target, meta_matrix = sample_data

        image = image.to(device).unsqueeze(0)
        target = target.to(device).unsqueeze(0)

        questions = image[:,:8]
        answers = image[:,8:]

        questions = torch.unsqueeze(questions, dim=2)
        answers = torch.unsqueeze(answers, dim=2)
        
        with torch.no_grad():
            predict, att_maps, feat_maps, aug_feat = model(image)
        pred = torch.max(predict, 1)[1]
        meta_to_rule(meta_matrix)
        
        print(f'Att target: {target.cpu()}, predict: {pred.cpu()}')
        rows = make_rows(questions, answers)
        
        plot_rows(rows, '6e-5')
        plot_feat_map(feat_maps, '6e-5')
        plot_att_map(att_maps, '6e-5')
        plot_aug_feat(aug_feat, '6e-5')
        

        

Progression Number
Distribute_Three Type
Distribute_Three Size
Arithmetic Color
Att target: tensor([6]), predict: tensor([6])
No Att target: tensor([6]), predict: tensor([6])
