In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.patches as mpatches
import umap.umap_ as umap

from torch.utils.data import random_split
from torch_geometric.loader import DataLoader as GeoDataLoader

from models import *
from utils.utils_preprocess import *

from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import mlflow

from utils.utils_train_model import *
from utils.utils_preprocess import *
from models import *

# Point MLflow at the local tracking server; it must be running before training starts.
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")

# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
device = (
    "mps" 
    if torch.backends.mps.is_available() 
    else "cuda" 
    if torch.cuda.is_available() 
    else "cpu"
)
print(f'Using device: {device}')

# NOTE(review): torch.load unpickles arbitrary Python objects - only load trusted files.
data_dict_path = './output_data/data_dict.pth'
data_dict = torch.load(data_dict_path)

print("Processing data")
print([len(data_dict[f'{i}']) for i in range(1, 8)])

# Step 1: Flatten the per-class sample lists (keys '1'..'7') and split 80/20.
# A single comprehension avoids the quadratic list copying of sum(lists, []).
data_list = [sample for i in range(1, 8) for sample in data_dict[str(i)]]
total_size = len(data_list)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
# Seed the split so the train/validation partition is identical on every run;
# an unseeded split makes metrics incomparable across kernel restarts.
split_seed = 42
train_data, val_data = random_split(
    data_list,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Step 2: Create DataLoader for training and validation
batch_size = 64
train_loader = GeoDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = GeoDataLoader(val_data, batch_size=batch_size, shuffle=False)

# Step 3: Model setup
# Some samples have no node features (x is None), so take the feature width
# from the first sample that has them.
num_node_features = next(data.x.shape[1] for data in data_list if data.x is not None)

hidden_channels = 64
model_dict = get_model_list(device, num_node_features, hidden_channels)
# model = Net_Alex(num_node_features, hidden_channels).to(device)
criterion = nn.CrossEntropyLoss()
num_epochs = 25

In [None]:
print('Start training ...')

# Single source of truth for the optimizer step size; also logged to MLflow.
learning_rate = 0.001


def _forward_and_loss(model, batch, criterion, image_only):
    """Run one forward pass and return ``(loss, logits_used_for_prediction)``.

    Args:
        model: The network being trained or evaluated.
        batch: A mini-batch already moved to the target device.
        criterion: Loss function applied to logits and ``batch.y``.
        image_only: If True the model is called with ``batch.image_features``
            only; otherwise it is called with the graph inputs plus the image
            features and is expected to return ``(gnn_logits, cnn_logits)``.

    Returns:
        The scalar loss tensor and the logits from which predictions are taken.
        For hybrid models the loss is the mean of both heads and predictions
        come from the GNN head; when the GNN head returns ``None`` only the
        CNN head is used.
    """
    if image_only:
        logits = model(batch.image_features)
        return criterion(logits, batch.y), logits

    gnn_logits, cnn_logits = model(
        batch.x, batch.edge_index, batch.edge_weight, batch.batch, batch.image_features
    )
    if gnn_logits is None:
        return criterion(cnn_logits, batch.y), cnn_logits
    loss = (criterion(gnn_logits, batch.y) + criterion(cnn_logits, batch.y)) / 2
    return loss, gnn_logits


for i, (model_name, model) in enumerate(model_dict.items()):
    # NOTE(review): even positions in model_dict are treated as image-only
    # models and odd positions as hybrid GNN+CNN models - this depends on the
    # ordering produced by get_model_list; confirm it alternates that way.
    image_only = i % 2 == 0

    # Log the values that are actually used rather than separate literals.
    params = {
        "model": model_name,
        "num_epochs": num_epochs,
        "lr": learning_rate,
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    cumulative_preds = []
    cumulative_labels = []
    with mlflow.start_run():
        mlflow.log_params(params)
        for epoch in range(1, num_epochs + 1):
            # ---- training pass ----
            model.train()
            total_loss = 0
            correct = 0
            for batch in train_loader:
                batch = batch.to(device)
                batch_y = batch.y

                optimizer.zero_grad()
                loss, logits = _forward_and_loss(model, batch, criterion, image_only)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                pred = logits.argmax(dim=1)
                correct += (pred == batch_y).sum().item()
            # Loss is averaged per batch, accuracy per sample.
            train_loss = total_loss / len(train_loader)
            train_accuracy = correct / len(train_data)

            # ---- validation pass ----
            model.eval()
            val_loss = 0
            val_correct = 0
            epoch_preds = []
            epoch_labels = []

            with torch.no_grad():
                for batch in val_loader:
                    batch = batch.to(device)
                    batch_y = batch.y

                    loss, logits = _forward_and_loss(model, batch, criterion, image_only)
                    val_loss += loss.item()

                    pred = logits.argmax(dim=1)
                    val_correct += (pred == batch_y).sum().item()

                    epoch_preds.extend(pred.cpu().numpy())
                    epoch_labels.extend(batch_y.cpu().numpy())
            avg_val_loss = val_loss / len(val_loader)
            val_accuracy = val_correct / len(val_data)
            mlflow.log_metric('train_loss', train_loss, step=epoch)
            mlflow.log_metric('train_accuracy', train_accuracy, step=epoch)
            mlflow.log_metric('val_loss', avg_val_loss, step=epoch)
            # The logged value is an accuracy (a fraction), not a count, so it
            # is named to match 'train_accuracy'.
            mlflow.log_metric('val_accuracy', val_accuracy, step=epoch)
            print(f'Epoch {epoch}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
                f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}')

            # Accumulate all validation predictions and labels across epochs
            cumulative_preds.extend(epoch_preds)
            cumulative_labels.extend(epoch_labels)

    # Make sure the output directory exists; torch.save does not create it.
    os.makedirs('./model', exist_ok=True)
    save_path = f'./model/{model_name}.pth'
    torch.save(model.state_dict(), save_path)
    print(f'Model saved to {save_path}')

# Grad-CAM

In [3]:
# Class index -> emotion name; the position in the list is the class index.
label_map = dict(enumerate([
    "Surprised",
    "Fearful",
    "Disgusted",
    "Happy",
    "Sad",
    "Angry",
    "Neutral",
]))

# Pick the compute device: MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'Using device: {device}')

# Load the preprocessed samples, stored as a dict keyed by class folder name.
data_dict_path = './output_data/data_dict.pth'
data_dict = torch.load(data_dict_path)

Using device: mps


In [4]:
# Flatten the per-class sample lists (keys '1'..'7') into one list.
# A single comprehension avoids the quadratic list copying of sum(lists, []).
data_list = [sample for i in range(1, 8) for sample in data_dict[str(i)]]
total_size = len(data_list)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
# Seed the 80/20 split so the same partition is produced on every run;
# an unseeded split yields a different "validation" set each time.
split_seed = 42
train_data, val_data = random_split(
    data_list,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
# Create DataLoader
batch_size = 64
train_loader = GeoDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = GeoDataLoader(val_data, batch_size=batch_size, shuffle=False)

num_class = 7

# Model setup
# Node-feature width is taken from the first sample that has node features.
num_node_features = next(data.x.shape[1] for data in data_list if data.x is not None)
hidden_channels = 64

In [None]:
def _load_eval_model(model_cls, weights_path, *ctor_args):
    """Build a model, load its saved weights, and return it ready for inference.

    Args:
        model_cls: Model class to instantiate.
        weights_path: Path to the ``state_dict`` file saved for that model.
        *ctor_args: Positional arguments forwarded to ``model_cls``.

    Returns:
        The instantiated model in eval mode, moved to the notebook's ``device``.
    """
    print("Load model...")
    model = model_cls(*ctor_args)
    # map_location lets weights saved on another device load on this one.
    state_dict = torch.load(weights_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model


# Hybrid GNN + CNN models take the graph feature sizes; image-only models take
# the number of output classes.
model_alexnet_gnn = _load_eval_model(Net_Alex, './model/model_Net_Alex.pth', num_node_features, hidden_channels)
model_resnet_gnn = _load_eval_model(Net_ResNet18, './model/model_Net_Resnet.pth', num_node_features, hidden_channels)
model_alexnet = _load_eval_model(AlexNet_Only, './model/model_alex_only.pth', num_class)
model_resnet = _load_eval_model(ResNet18_Only, './model/model_Resnet18_only.pth', num_class)
model_vgg = _load_eval_model(VGG16_Only, './model/model_VGG16_only.pth', num_class)
model_vgg_gnn = _load_eval_model(Net_VGG, './model/model_Net_VGG.pth', num_node_features, hidden_channels)

Load model...
Load model...
Load model...
Load model...
Load model...
Load model...


In [15]:
# One aligned test image per class folder (1-7), identified by its 4-digit id.
# Alternative sample set used earlier (same folder order):
#     0002, 0274, 0007, 0003, 0001, 0017, 2389
_sample_ids = ['0004', '0377', '0011', '0009', '0005', '0027', '2390']
test_pth = [
    f'Image_data/DATASET/test/{class_dir}/test_{image_id}_aligned.jpg'
    for class_dir, image_id in enumerate(_sample_ids, start=1)
]

# Image preprocessing: resize the short side to 256, take the central 224x224
# crop, convert to a tensor and normalise each channel.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
transform = transforms.Compose(_preprocess_steps)

In [16]:
# Output directory for the overlays; created up front (including parents).
gradcam_dir = 'static/GradCam'
os.makedirs(gradcam_dir, exist_ok=True)

# Per model class: whether it must be wrapped in NetWrapper (hybrid GNN + CNN
# models need fixed graph inputs so Grad-CAM can call them with an image only)
# and how to reach the layer whose activations Grad-CAM visualises.
_gradcam_specs = {
    'Net_Alex': (True, lambda net: net.alexnet.features[-1]),
    'Net_ResNet18': (True, lambda net: net.resnet.layer4[-1]),
    'Net_VGG': (True, lambda net: net.vgg16.features[-1]),
    'AlexNet_Only': (False, lambda net: net.alexnet.features[-1]),
    'ResNet18_Only': (False, lambda net: net.resnet.layer4[-1]),
    'VGG16_Only': (False, lambda net: net.vgg16.features[-1]),
}

# Geometric part of the preprocessing only. It must mirror the Resize /
# CenterCrop steps of `transform` so the heatmap and the background image
# cover exactly the same pixels.
overlay_crop = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])

# Fixed graph inputs handed to NetWrapper; they do not depend on the image.
# NOTE(review): edge data comes from data_dict['1'][0] while `batch` is the
# first validation sample - confirm that mixing the two is intended.
edge_weight = data_dict['1'][0].edge_weight
edge_index = data_dict['1'][0].edge_index
batch = val_loader.dataset[0]

for model in [model_alexnet_gnn, model_resnet_gnn, model_alexnet, model_resnet, model_vgg, model_vgg_gnn]:

    model_name = model.__class__.__name__
    print(model_name)

    if model_name not in _gradcam_specs:
        # Fail loudly instead of silently reusing the previous model's wrapper.
        raise ValueError(f'No Grad-CAM target layer configured for {model_name}')
    needs_wrapper, pick_target_layer = _gradcam_specs[model_name]

    if needs_wrapper:
        wrapped_model = NetWrapper(model, edge_index, edge_weight, batch)
        target_layers = [pick_target_layer(wrapped_model.model)]
    else:
        wrapped_model = model
        target_layers = [pick_target_layer(model)]

    # One CAM object per model; the context manager removes its hooks on exit.
    with GradCAM(model=wrapped_model, target_layers=target_layers) as cam:
        for index, pth in enumerate(test_pth):

            # Force 3 channels so grayscale / RGBA files work with the
            # 3-channel normalisation, and close the file handle promptly.
            with Image.open(pth) as raw_image:
                input_image = raw_image.convert('RGB')

            input_tensor = transform(input_image).unsqueeze(0)  # Add batch dimension

            # Background for the overlay: same resize + crop as the model
            # input, scaled to [0, 1] float32 as show_cam_on_image expects.
            rgb_img = np.asarray(overlay_crop(input_image), dtype=np.float32) / 255.0

            # No explicit targets: Grad-CAM explains the top-scoring class.
            grayscale_cam = cam(input_tensor=input_tensor)

            visualization = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)

            # save the plot
            plt.imsave(f'{gradcam_dir}/{model_name}_{index + 1}_2.png', visualization)
            

Net_Alex
Net_ResNet18
AlexNet_Only
ResNet18_Only
VGG16_Only
Net_VGG


# UMAP

In [None]:
# Attribute holding the image backbone on each hybrid (GNN + CNN) model class.
# NOTE(review): names follow how these classes are accessed for Grad-CAM in
# this notebook - verify against the model definitions.
_hybrid_backbone_attr = {
    Net_Alex: 'alexnet',
    Net_ResNet18: 'resnet',
    Net_VGG: 'vgg16',
}

for model in [model_alexnet_gnn, model_resnet_gnn, model_alexnet, model_resnet, model_vgg, model_vgg_gnn]:
    model_name = model.__class__.__name__
    print(model_name)
    features_list = []
    labels_list = []

    # None for image-only models, otherwise the backbone attribute name.
    backbone_attr = next(
        (attr for cls, attr in _hybrid_backbone_attr.items() if isinstance(model, cls)),
        None,
    )

    # Inference only: disable dropout / batch-norm updates and gradients.
    model.eval()
    with torch.no_grad():
        for batch in train_loader:
            batch = batch.to(device)
            x = batch.x
            edge_index = batch.edge_index
            edge_weight = batch.edge_weight
            image_features = batch.image_features
            batch_y = batch.y

            if backbone_attr is not None:
                # Hybrid GNN + image model
                if x is not None and edge_index is not None:
                    # The forward pass returns two outputs (one per head);
                    # concatenate them feature-wise.
                    first_out, second_out = model(x, edge_index, edge_weight, batch.batch, image_features)
                    combined_features = torch.cat((first_out, second_out), dim=1)
                else:
                    # No graph data: use this model's own image backbone
                    # (not always AlexNet).
                    combined_features = getattr(model, backbone_attr)(image_features)
            else:
                # Image-only model: use its output directly
                combined_features = model(image_features)

            # Collect features and labels
            features_list.append(combined_features.cpu().numpy())
            labels_list.extend(batch_y.cpu().numpy())

    # Combine features and labels
    features = np.vstack(features_list)
    labels = np.array(labels_list)

    # Dimensionality reduction with UMAP (seeded so the layout is reproducible)
    print("Performing UMAP dimensionality reduction...")
    reducer = umap.UMAP(random_state=42)
    low_dim_embeddings = reducer.fit_transform(features)

    # Visualization
    print("Visualizing data...")
    fig, ax = plt.subplots(figsize=(12, 10))

    # Create scatter plot
    scatter = ax.scatter(
        low_dim_embeddings[:, 0], 
        low_dim_embeddings[:, 1], 
        c=labels, 
        cmap='tab10', 
        s=40, 
        edgecolor='k', 
        alpha=0.8
    )

    # Create a legend instead of a colorbar
    unique_labels = np.unique(labels)
    handles = [
        mpatches.Patch(color=scatter.cmap(scatter.norm(label)), label=label_map[int(label)]) 
        for label in unique_labels
    ]
    ax.legend(
        handles=handles, 
        title="Emotion Categories", 
        loc='upper right', 
        fontsize=10, 
        title_fontsize=12
    )

    # Add titles and labels
    ax.set_title(f'UMAP projection - {model_name}', fontsize=16)
    ax.set_xlabel('Dimension 1', fontsize=14)
    ax.set_ylabel('Dimension 2', fontsize=14)

    # Customize grid and background
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_facecolor('#f7f7f7')

    # Save the plot
    output_dir = './UMAPplots'
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(f'{output_dir}/{model_name}.png', dpi=300, bbox_inches='tight')

    # Render inline, then release the figure so six of them do not pile up.
    plt.show()
    plt.close(fig)