In [4]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Directory containing the images.
# Raw string: in a normal literal, backslashes followed by 's', 'D' and 'c' are
# invalid escape sequences (SyntaxWarning on recent Python versions). The
# resulting path value is unchanged.
image_dir = r"D:\sem6\DL\cavallo"

# Read and preprocess images
def read_images(image_dir, size=(120, 120)):
    """Load every readable image in ``image_dir`` and resize it to ``size``.

    Args:
        image_dir: Directory whose entries are passed to ``cv2.imread``.
        size: Target size handed to ``cv2.resize``.

    Returns:
        A list of resized images, in ``os.listdir`` order. Entries for which
        ``cv2.imread`` returns ``None`` are skipped.
    """
    loaded = []
    for entry in os.listdir(image_dir):
        frame = cv2.imread(os.path.join(image_dir, entry))
        if frame is None:
            # Not readable as an image; ignore this entry.
            continue
        loaded.append(cv2.resize(frame, size))
    return loaded

# Load every readable image from image_dir, resized to the default 120x120.
images = read_images(image_dir)

# Function to divide an image into 9 equal pieces
def divide_image(image, piece_size=(40, 40)):
    """Cut ``image`` into tiles of ``piece_size``, scanning row by row.

    Args:
        image: Array with three dimensions (rows, columns, channels).
        piece_size: ``(rows, columns)`` of each tile.

    Returns:
        A list of array slices ordered top-to-bottom, then left-to-right.
        A 120x120 image with the default tile size yields 9 tiles.
    """
    height, width, _ = image.shape
    tile_h = piece_size[0]
    tile_w = piece_size[1]
    tiles = []
    for top in range(0, height, tile_h):
        for left in range(0, width, tile_w):
            tiles.append(image[top:top + tile_h, left:left + tile_w])
    return tiles

# Function to randomly shuffle image pieces and store original positions
def generate_combinations(parts, num_combinations=10):
    """Produce random re-orderings of ``parts`` together with the permutations used.

    Args:
        parts: Sequence of pieces to shuffle.
        num_combinations: How many shuffled copies to generate.

    Returns:
        ``(combinations, original_positions)`` where ``combinations[k][j]`` is
        ``parts[original_positions[k][j]]``; each entry of
        ``original_positions`` is an array from ``np.random.permutation``.
    """
    shuffled_sets = []
    permutations = []
    n_parts = len(parts)
    for _ in range(num_combinations):
        order = np.random.permutation(n_parts)
        permutations.append(order)
        shuffled_sets.append([parts[src] for src in order])
    return shuffled_sets, permutations

# CNN model to extract features from each piece
class PieceCNN(nn.Module):
    """Convolutional encoder mapping each image piece to a 64-dimensional feature.

    Three 3x3 convolutions (3 -> 32 -> 64 -> 128 channels, padding 1) are each
    followed by ReLU and 2x2 max pooling. ``fc1`` expects 128 * 5 * 5 inputs,
    which matches 40x40 pieces after three halvings (40 -> 20 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # 128 channels over a 5x5 spatial map, projected to the embedding size.
        self.fc1 = nn.Linear(128 * 5 * 5, 64)

    def forward(self, x):
        """Return non-negative embeddings of shape ``[x.size(0), 64]``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            # conv -> ReLU -> halve the spatial resolution
            x = F.max_pool2d(F.relu(conv(x)), 2)
        flat = x.view(x.size(0), -1)
        return F.relu(self.fc1(flat))

# Graph network to predict piece positions
class GraphNetwork(nn.Module):
    """Graph network that scores, for every piece, each candidate original position.

    Two ``GCNConv`` layers (64 -> 128 -> 128) are followed by a linear head that
    emits ``num_positions`` logits per node (piece).

    The previous version mean-pooled all nodes and emitted a single scalar per
    graph. ``train`` feeds the output to ``F.cross_entropy`` together with one
    target per piece, which requires one row of class logits per piece; the
    pooled ``[num_graphs, 1]`` output could not match those targets.

    Args:
        num_positions: Number of position classes per piece (9 for a 3x3 grid).
    """

    def __init__(self, num_positions=9):
        super(GraphNetwork, self).__init__()
        self.conv1 = GCNConv(64, 128)
        self.conv2 = GCNConv(128, 128)
        self.fc1 = nn.Linear(128, num_positions)  # Position logits for each piece

    def forward(self, x, edge_index, batch=None):
        """Return logits of shape ``[num_nodes, num_positions]``.

        Args:
            x: Node features with 64 channels per node.
            edge_index: Graph connectivity passed to both ``GCNConv`` layers.
            batch: Accepted for backward compatibility with existing callers;
                unused because the prediction is per node, not per graph.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # No graph-level pooling: keep one row per piece so it lines up with
        # the per-piece position targets.
        return self.fc1(x)


# Data preparation
def prepare_data(images, num_combinations=10):
    """Build shuffled-piece tensors and their permutation targets.

    Args:
        images: Iterable of images (rows, columns, channels) accepted by
            ``divide_image``.
        num_combinations: Number of random shuffles generated per image.

    Returns:
        ``(input_data, target_data)``: for every image and every shuffle,
        ``input_data`` holds a float tensor of stacked pieces in
        ``[num_pieces, C, H, W]`` layout scaled by 1/255, and ``target_data``
        holds the matching permutation as a ``torch.long`` tensor.
    """
    input_data = []
    target_data = []
    for image in images:
        # Convert each piece to a float CHW tensor once per image; the shuffles
        # below only re-order these tensors instead of re-converting them for
        # every combination.
        piece_tensors = [
            torch.tensor(part).permute(2, 0, 1).float() / 255
            for part in divide_image(image)
        ]
        combinations, original_positions = generate_combinations(piece_tensors, num_combinations)
        for combination, positions in zip(combinations, original_positions):
            # torch.stack copies, so samples do not alias the shared pieces.
            input_data.append(torch.stack(combination))
            target_data.append(torch.tensor(positions, dtype=torch.long))
    return input_data, target_data

# Build the shuffled-piece tensors and their permutation targets for all images.
# NOTE: the training cell below calls prepare_data(images) again and rebinds both names.
input_data, target_data = prepare_data(images)

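# Superseded draft that flattened each piece; the active create_data_loader is defined in the next cell.
# # Create a DataLoader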
# def create_data_loader(input_data, target_data):
#     graph_data_list = []
#     for data, target in zip(input_data, target_data):
#         x = data.view(-1, 3 * 40 * 40)  # Flatten each piece
#         y = target
#         edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)  # Dummy edge index
#         graph_data = Data(x=x, edge_index=edge_index, y=y)
#         graph_data_list.append(graph_data)
#     return DataLoader(graph_data_list, batch_size=1, shuffle=True)

# loader = create_data_loader(input_data, target_data)

# Models and optimizer are initialized in the next cell.



In [5]:
# Adjust the data preparation to keep the spatial dimensions
# Create a DataLoader without flattening x; it should already be in [C, H, W] per piece
def create_data_loader(input_data, target_data):
    """Wrap each (pieces, target) pair in a graph ``Data`` object and return a loader.

    Args:
        input_data: Sequence of piece tensors; used as node features ``x``
            without flattening (assumed ``[num_pieces, C, H, W]``).
        target_data: Sequence of per-sample targets stored as ``y``.

    Returns:
        A shuffling ``DataLoader`` over the graphs with ``batch_size=1``.

    NOTE(review): the edge index is a placeholder that only links node 0 and
    node 1 (in both directions); no other nodes are connected.
    """
    def _to_graph(pieces, positions):
        # A fresh dummy edge tensor is created for every graph.
        dummy_edges = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        return Data(x=pieces, edge_index=dummy_edges, y=positions)

    graphs = [
        _to_graph(pieces, positions)
        for pieces, positions in zip(input_data, target_data)
    ]
    return DataLoader(graphs, batch_size=1, shuffle=True)

def train(loader, piece_cnn, graph_network, optimizer, num_epochs):
    """Jointly train the piece encoder and the graph network.

    For every batch the pieces are reshaped to ``[-1, 3, 40, 40]``, embedded by
    ``piece_cnn``, passed through ``graph_network`` and scored with
    cross-entropy against ``data.y``. The mean loss is printed once per epoch.

    Args:
        loader: Sized iterable of batches exposing ``x``, ``edge_index``, ``y``
            and optionally ``batch`` (node-to-graph assignment).
        piece_cnn: Module mapping ``[N, 3, 40, 40]`` pieces to node features.
        graph_network: Module called as ``graph_network(x, edge_index, batch)``.
        optimizer: Optimizer over the parameters of both modules.
        num_epochs: Number of passes over ``loader``.
    """
    piece_cnn.train()
    graph_network.train()
    num_batches = len(loader)
    for epoch in range(num_epochs):
        total_loss = 0.0
        for data in loader:
            optimizer.zero_grad()
            pieces = data.x.view(-1, 3, 40, 40)
            embeddings = piece_cnn(pieces)
            # Prefer the node-to-graph assignment supplied with the batch; an
            # all-zeros vector is only correct when the batch is one graph.
            batch = getattr(data, 'batch', None)
            if batch is None:
                batch = torch.zeros(pieces.size(0), dtype=torch.long, device=pieces.device)
            output = graph_network(embeddings, data.edge_index, batch)
            loss = F.cross_entropy(output, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # An empty loader would otherwise raise ZeroDivisionError here.
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}')


# Seed the NumPy and PyTorch RNGs so the random piece permutations, the weight
# initialization and the loader shuffling are repeatable across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

piece_cnn = PieceCNN()
graph_network = GraphNetwork()
# One Adam optimizer updates the parameters of both networks jointly.
optimizer = optim.Adam(list(piece_cnn.parameters()) + list(graph_network.parameters()), lr=0.001)

# Regenerate the shuffled samples (after seeding) and wrap them in a loader.
input_data, target_data = prepare_data(images)
loader = create_data_loader(input_data, target_data)
num_epochs = 10  # Define the number of epochs
train(loader, piece_cnn, graph_network, optimizer, num_epochs)
