In [31]:
import numpy as np 
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage.metrics import structural_similarity as ssim
from torchvision.transforms import Normalize

In [32]:
def divide_image(image):
    parts = []
    height, width, _ = image.shape
    part_height = height // 3
    part_width = width // 3
    
    for i in range(3):
        for j in range(3):
            part = image[i*part_height:(i+1)*part_height, j*part_width:(j+1)*part_width]
            parts.append(part)
    
    return parts

def generate_combinations(parts, num_combinations):
    combinations = []
    original_positions = []
    indices = list(range(len(parts)))
    
    for _ in range(num_combinations):
        random.shuffle(indices)
        combination = [parts[i] for i in indices]
        combinations.append(combination)
        original_positions.append(indices.copy())
    
    return combinations, original_positions
def stitch_shuffled_image(parts):
    num_parts = len(parts)
    part_size = parts[0].shape[0]  # Assuming all parts are square
    
    stitched_image_size = int(np.sqrt(num_parts) * part_size)
    stitched_image = np.zeros((stitched_image_size, stitched_image_size, parts[0].shape[2]), dtype=np.uint8)
    
    for i in range(stitched_image.shape[0] // part_size):
        for j in range(stitched_image.shape[1] // part_size):
            part_index = i * int(stitched_image.shape[0] / part_size) + j
            stitched_image[i*part_size:(i+1)*part_size, j*part_size:(j+1)*part_size] = parts[part_index]
    
    return stitched_image
# checking if the target sequence is correct
def reconstructed_image(img, non_converted_target_data, test=False):
    
    sequenced = [0] * 9
    for i in range(0,27,3):

        # stack the 3 channels to get the original image
        blue_channel = img[:,:,i]
        green_channel = img[:,:,i+1]
        red_channel = img[:,:,i+2]

        # stack the 3 channels to get the original image
        tile = np.stack((blue_channel, green_channel, red_channel), axis=2)

        if test:
            sequenced[i // 3] = tile
        else:
            sequenced[non_converted_target_data[i // 3]] = tile

    stiched_img = stitch_shuffled_image(sequenced)
    return stiched_img


In [33]:
class JigsawDataset(Dataset):
    def __init__(self, input_data, target_data, transform=None):
        self.input_data = input_data
        self.target_data = target_data
        self.transform = transform
        
    def __len__(self):
        return len(self.input_data)
    
    def __getitem__(self, idx):
        image = self.input_data[idx]
        target = self.target_data[idx]
        
        # Convert NumPy arrays to float tensors
        image = torch.from_numpy(image).float()
        target = torch.from_numpy(target)
        
        if self.transform:
            image = self.transform(image)
        
        return image, target


In [34]:
# Define the JigsawModel class with batch normalization layers
class JigsawModel(nn.Module):
    def __init__(self):
        super(JigsawModel, self).__init__()
        self.conv1 = nn.Conv2d(27, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)  # Batch normalization layer
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)  # Batch normalization layer
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 10 * 10, 4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, 512) 
        self.fc4 = nn.Linear(512, 81)
        
    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 128 * 10 * 10)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x


In [35]:
class BoundaryLoss(nn.Module):
    def __init__(self):
        super(BoundaryLoss, self).__init__()

    def forward(self, outputs, labels):
        batch_size = outputs.size(0)
        top_bottom_ssim = torch.zeros(batch_size)
        left_right_ssim = torch.zeros(batch_size)

        for i in range(batch_size):
            # Reshape the outputs and labels to match SSIM function requirements
            output_img = outputs[i].view(9, 9).unsqueeze(0).unsqueeze(0)
            label_img = labels[i].view(9, 9).unsqueeze(0).unsqueeze(0)

            # Compute SSIM
            top_bottom_ssim[i] = 1 - torch.mean(torch.abs(output_img[0, :, 0, :] - label_img[0, :, -1, :]))
            left_right_ssim[i] = 1 - torch.mean(torch.abs(output_img[0, 0, :, :] - label_img[0, -1, :, :]))

        avg_tb_ssim = torch.mean(top_bottom_ssim)
        avg_lr_ssim = torch.mean(left_right_ssim)

        loss = avg_tb_ssim + avg_lr_ssim

        return loss


In [36]:
# Load and preprocess data
image_dir = "cavallo"

images = []
for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    image = cv2.imread(image_path)
    
    if image is not None:
        image = cv2.resize(image, (120, 120))
        images.append(image)
input_data = []
target_data = []


In [37]:
for image in images:
    parts = divide_image(image)
    combinations, original_positions = generate_combinations(parts, 10)
    
    for idx, combination in enumerate(combinations):

        # shape of combination is (9, 40, 40, 3)
        combination = np.array(combination).transpose(0, 3, 1, 2)
        combination = np.concatenate(combination, axis=0).transpose(1, 2, 0)
        input_data.append(combination)

        dummy_target = np.zeros((9, 9), dtype=np.uint8)
        for i in range(9):
            dummy_target[i, original_positions[idx][i]] = 1

        target_data.append(dummy_target.flatten())

input_data = np.array(input_data)
target_data = np.array(target_data)

In [38]:
model = JigsawModel()
boundary_loss_fn = BoundaryLoss()
adversarial_loss_fn = nn.BCELoss()
weight_adversarial = 0.5  # Adjust as needed
weight_boundary = 0.5 
num_epochs = 10
batch_size = 64

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)



In [39]:
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(input_data, target_data, test_size=0.2, random_state=42)
print("shape of training data:", X_train.shape)



# Apply normalization transform to your dataset
train_dataset = JigsawDataset(X_train, y_train)
test_dataset = JigsawDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)



shape of training data: (20984, 40, 40, 27)


In [40]:
# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.permute(0, 3, 1, 2).float()
        
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        outputs = torch.sigmoid(outputs)
        
        # Calculate adversarial loss
        adversarial_loss = adversarial_loss_fn(outputs, labels.float())
        b_loss = boundary_loss_fn(outputs, labels)
        total_loss = weight_adversarial * adversarial_loss + weight_boundary * b_loss
        # Apply gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        total_loss.backward()
        optimizer.step()
        
        running_loss += total_loss.item()
        
        if i % 100 == 99:
            print('Epoch: %d, Batch: %5d, Loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

        # Calculate accuracy
        _, predicted_indices = torch.max(outputs, 1)  # Get the index with the highest probability
        predicted_labels = torch.zeros_like(labels)
        for i, idx in enumerate(predicted_indices):
            predicted_labels[i, idx] = 1  # Convert index to one-hot encoding

        correct = (predicted_labels == labels).all(dim=1).sum().item()
        total_correct += correct
        total_samples += labels.size(0)

    accuracy = total_correct / total_samples
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss}, Accuracy: {accuracy}')


Epoch: 1, Batch:   100, Loss: 0.714
Epoch: 1, Batch:   200, Loss: 0.708
Epoch: 1, Batch:   300, Loss: 0.706
Epoch [1/10], Loss: 19.01365751028061, Accuracy: 0.0
Epoch: 2, Batch:   100, Loss: 0.703
Epoch: 2, Batch:   200, Loss: 0.700
Epoch: 2, Batch:   300, Loss: 0.699
Epoch [2/10], Loss: 18.83733904361725, Accuracy: 0.0
Epoch: 3, Batch:   100, Loss: 0.695
Epoch: 3, Batch:   200, Loss: 0.695
Epoch: 3, Batch:   300, Loss: 0.694
Epoch [3/10], Loss: 18.732250094413757, Accuracy: 0.0
Epoch: 4, Batch:   100, Loss: 0.691
Epoch: 4, Batch:   200, Loss: 0.690
Epoch: 4, Batch:   300, Loss: 0.690
Epoch [4/10], Loss: 18.629630029201508, Accuracy: 0.0
Epoch: 5, Batch:   100, Loss: 0.686
Epoch: 5, Batch:   200, Loss: 0.686
Epoch: 5, Batch:   300, Loss: 0.685
Epoch [5/10], Loss: 18.53717303276062, Accuracy: 0.0
Epoch: 6, Batch:   100, Loss: 0.682
Epoch: 6, Batch:   200, Loss: 0.683
Epoch: 6, Batch:   300, Loss: 0.681
Epoch [6/10], Loss: 18.378956079483032, Accuracy: 0.0
Epoch: 7, Batch:   100, Loss: 0

In [41]:
def recur(sequence, outputs):
    sequence = np.array(sequence)
    for i in range(len(sequence)):
        while np.sum(sequence == sequence[i]) > 1:  # If the current element is a duplicate
            # Find the index in outputs for the current sequence element that is not yet in the updated sequence
            scores = outputs[i]
            sorted_indices = np.argsort(scores)[::-1]  # Indices of scores sorted in descending order
            for idx in sorted_indices:
                if idx not in sequence:
                    sequence[i] = idx
                    break
                
    return sequence.tolist()

In [42]:
correct = 0
total = 0
per_tile_accuracy = 0

model.eval()
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.permute(0, 3, 1, 2).float()
        outputs = model(inputs)

        # reshape the output to 9x9 matrix
        outputs = outputs.reshape(-1, 9, 9)
        predicted = torch.argmax(outputs, dim=1)

        # now doing the same for the target
        labels = labels.reshape(-1, 9, 9)
        target = torch.argmax(labels, dim=1)

        # check if the predicted sequence is correct
        for i in range(len(predicted)):
            updated_predicted = recur(predicted[i], outputs[i, : , :].numpy())
            
            if torch.equal(torch.tensor(updated_predicted), target[i]):
                correct += 1
            total += 1

            per_tile_accuracy += (np.array(updated_predicted) == target[i].numpy()).sum() / 9

print('Accuracy on test images: %d %%' % (100 * correct / total))
print('Per tile accuracy on test images: %d %%' % (100 * per_tile_accuracy / total))

Accuracy on test images: 0 %
Per tile accuracy on test images: 18 %


In [44]:
import os

# Create the directory if it doesn't exist
os.makedirs('saved_models', exist_ok=True)

# Save the model
torch.save(model.state_dict(), 'saved_models/jigsaw_model.pth')


In [45]:
idx = random.randint(0, len(X_test))
image = X_test[idx]
target = y_test[idx]
inputs = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
outputs = model(inputs)
outputs = outputs.reshape(9, 9)
print(outputs)
converted_target = np.argmax(target.reshape(9, 9), axis=1)
print(converted_target)

tensor([[-4.0118e+00,  2.0453e+00,  2.2418e+00,  2.2821e+00,  2.5797e+00,
          2.6337e+00,  2.1445e+00,  2.1047e+00,  2.1770e+00],
        [-1.5043e+00, -2.8925e-01, -1.1173e+00, -1.2400e+00, -1.4567e+00,
         -1.2544e+00, -1.8207e+00, -2.1900e+00, -1.3938e+00],
        [-1.7818e+00, -1.1564e+00, -1.0048e+00, -6.2732e-01, -9.2533e-01,
         -1.6911e+00, -1.3251e+00, -1.5247e+00, -1.4480e+00],
        [-3.1206e+00, -2.0357e+00, -1.3962e+00, -2.1916e+00, -2.2762e+00,
         -4.7843e-01, -5.4281e-01, -3.9811e-03,  1.3640e-01],
        [-1.8509e+00, -1.0560e+00, -5.7689e-01, -4.8483e-01, -9.5806e-01,
         -1.4551e+00, -1.7789e+00, -1.6852e+00, -2.0565e+00],
        [-2.6980e+00, -2.1436e+00, -2.1027e+00, -7.1422e-01, -1.8187e+00,
         -9.5641e-01, -3.7571e-01, -4.4514e-03, -9.5942e-01],
        [-2.3109e+00, -1.0724e+00, -8.4616e-01, -1.3514e+00, -5.1938e-01,
         -8.6498e-01, -1.7045e+00, -1.5688e+00, -1.7525e+00],
        [-2.3116e+00, -1.4784e+00, -1.3319e+00, 