# Top Down Attention

In [1]:
import os
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
import random

from dataloader import DatasetLoader
from model.attention_cnn import Net
from plot.loss import plot_loss
from plot.vis import visualize_attention_on_image

import sys

In [2]:
ex_name='w_attention'
apply_attention = True
save_dir = f'/home/kurita/GitHub/attention/output/{ex_name}'
os.makedirs(save_dir, exist_ok=True)
dataset_dir = '/home/kurita/GitHub/attention/data'


In [3]:
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ハイパーパラメータの設定
num_epochs = 200
batch_size = 128
learning_rate = 0.001
momentum = 0.9

# CIFAR-10データセットの準備
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
#trainset = DatasetLoader(root=dataset_dir, phase='train', transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
#testset = DatasetLoader(root=dataset_dir, phase='test', transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



Using device: cuda


In [4]:
# モデルのインスタンス化
net = Net(apply_attention=apply_attention).to(device)

# 損失関数と最適化関数の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

# 学習の記録用リスト
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

# 学習ループ
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        # inputs, labels = data['image'].to(device), data['label'].to(device)

        if (i == 0):
            print(f'Epoch {epoch+1}\n-------------------------------')
#        print('i=', i)

        optimizer.zero_grad()
#        print(inputs.shape, labels.shape)
        
        outputs = net(inputs)
#        print(outputs.shape)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_loss = running_loss / len(trainloader)
    train_accuracy = 100 * correct_train / total_train
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    # テストフェーズ
    net.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for i, data in enumerate(testloader):
            images, labels = data[0].to(device), data[1].to(device)
            # images, labels = data['image'].to(device), data['label'].to(device)

            outputs = net(images) # モデルの出力
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_loss = test_loss / len(testloader)
    test_accuracy = 100 * correct_test / total_test
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

print('Finished Training')
plot_loss(os.path.join(save_dir, 'loss.png'), train_losses, test_losses, train_accuracies, test_accuracies)



Epoch 1
-------------------------------
Epoch [1/200], Train Loss: 2.2812, Train Accuracy: 16.76%
Epoch [1/200], Test Loss: 2.2313, Test Accuracy: 23.38%
Epoch 2
-------------------------------
Epoch [2/200], Train Loss: 2.1452, Train Accuracy: 23.84%
Epoch [2/200], Test Loss: 2.0246, Test Accuracy: 28.24%
Epoch 3
-------------------------------
Epoch [3/200], Train Loss: 1.9702, Train Accuracy: 29.33%
Epoch [3/200], Test Loss: 1.8455, Test Accuracy: 34.78%
Epoch 4
-------------------------------
Epoch [4/200], Train Loss: 1.8223, Train Accuracy: 34.14%
Epoch [4/200], Test Loss: 1.6780, Test Accuracy: 40.49%
Epoch 5
-------------------------------
Epoch [5/200], Train Loss: 1.6787, Train Accuracy: 38.47%
Epoch [5/200], Test Loss: 1.5419, Test Accuracy: 44.20%
Epoch 6
-------------------------------
Epoch [6/200], Train Loss: 1.5832, Train Accuracy: 41.39%
Epoch [6/200], Test Loss: 1.4855, Test Accuracy: 46.28%
Epoch 7
-------------------------------
Epoch [7/200], Train Loss: 1.5150, T

In [5]:
# fc1出力の256の特徴量を得る
# 中間層の特徴ベクトルの可視化 
net.eval()
all_features = []
all_labels = []
selected_batch_indices = random.sample(range(len(testloader)), 10)
with torch.no_grad():
    for i, data in enumerate(testloader):
        if i  in selected_batch_indices:  # 10バッチのみ使用 
            images, labels = data[0], data[1]
            # images, labels = data['image'], data['label']
            images = images.to(device)
            output, features = net(images, return_feat=True)
            features = features.cpu().numpy()
            all_features.append(features)
            all_labels.append(labels.cpu().numpy())

all_features = np.concatenate(all_features, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

pca = PCA(n_components=2)
principal_components = pca.fit_transform(all_features.reshape(all_features.shape[0], -1)) # 3D特徴マップを2Dに変換

plt.figure(figsize=(8, 6))
colors = plt.cm.get_cmap('tab10', len(classes)) # または plt.cm.get_cmap('Set1', len(classes)) など

for i in range(len(classes)):
    indices = np.where(all_labels == i)[0]
    plt.scatter(principal_components[indices, 0], principal_components[indices, 1],
                label=classes[i], alpha=0.6, c=[colors(i)])
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('PCA of Intermediate Layer Features')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(save_dir, 'pca.png'))
plt.close()


if apply_attention:

    # テストデータからランダムなサンプルを選択
    num_samples_to_visualize = 5
    indices = random.sample(range(len(testset)), num_samples_to_visualize)
    for index in indices:
        img, label = testset[index][0], testset[index][1]
        # img, label = testset[index]['image'], testset[index]['label']
        original_img = img.clone() # 可視化用にオリジナルのテンソルを保持
        class_name = classes[label]
        print(f"Visualizing top-down attention for sample with label: {class_name}")
        save_path = os.path.join(save_dir, f'vis_{index}.png')
        visualize_attention_on_image(device, net, img, original_img, classes, save_path)

  colors = plt.cm.get_cmap('tab10', len(classes)) # または plt.cm.get_cmap('Set1', len(classes)) など


Visualizing top-down attention for sample with label: horse
Visualizing top-down attention for sample with label: horse
Visualizing top-down attention for sample with label: frog
Visualizing top-down attention for sample with label: cat
Visualizing top-down attention for sample with label: ship
