In [1]:
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage.metrics import structural_similarity as ssim
from torchvision.transforms import Normalize

In [2]:
def divide_image(image):
    parts = []
    height, width, _ = image.shape
    part_height = height // 3
    part_width = width // 3

    for i in range(3):
        for j in range(3):
            part = image[i*part_height:(i+1)*part_height, j*part_width:(j+1)*part_width]
            parts.append(part)

    return parts

def generate_combinations(parts, num_combinations):
    combinations = []
    original_positions = []
    indices = list(range(len(parts)))

    for _ in range(num_combinations):
        random.shuffle(indices)
        combination = [parts[i] for i in indices]
        combinations.append(combination)
        original_positions.append(indices.copy())

    return combinations, original_positions
def stitch_shuffled_image(parts):
    num_parts = len(parts)
    part_size = parts[0].shape[0]  # Assuming all parts are square

    stitched_image_size = int(np.sqrt(num_parts) * part_size)
    stitched_image = np.zeros((stitched_image_size, stitched_image_size, parts[0].shape[2]), dtype=np.uint8)

    for i in range(stitched_image.shape[0] // part_size):
        for j in range(stitched_image.shape[1] // part_size):
            part_index = i * int(stitched_image.shape[0] / part_size) + j
            stitched_image[i*part_size:(i+1)*part_size, j*part_size:(j+1)*part_size] = parts[part_index]

    return stitched_image
# checking if the target sequence is correct
def reconstructed_image(img, non_converted_target_data, test=False):

    sequenced = [0] * 9
    for i in range(0,27,3):

        # stack the 3 channels to get the original image
        blue_channel = img[:,:,i]
        green_channel = img[:,:,i+1]
        red_channel = img[:,:,i+2]

        # stack the 3 channels to get the original image
        tile = np.stack((blue_channel, green_channel, red_channel), axis=2)

        if test:
            sequenced[i // 3] = tile
        else:
            sequenced[non_converted_target_data[i // 3]] = tile

    stiched_img = stitch_shuffled_image(sequenced)
    return stiched_img


In [3]:
class JigsawDataset(Dataset):
    def __init__(self, input_data, target_data, transform=None):
        self.input_data = input_data
        self.target_data = target_data
        self.transform = transform

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        image = self.input_data[idx]
        target = self.target_data[idx]

        # Convert NumPy arrays to float tensors
        image = torch.from_numpy(image).float()
        target = torch.from_numpy(target)

        if self.transform:
            image = self.transform(image)

        return image, target


In [4]:
# Define the JigsawModel class with batch normalization layers
class JigsawModel(nn.Module):
    def __init__(self):
        super(JigsawModel, self).__init__()
        self.conv1 = nn.Conv2d(27, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)  # Batch normalization layer
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)  # Batch normalization layer
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 10 * 10, 4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 81)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.reshape(-1, 128 * 10 * 10)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x


In [5]:
class BoundaryLoss(nn.Module):
    def __init__(self):
        super(BoundaryLoss, self).__init__()

    def forward(self, outputs, labels):
        batch_size = outputs.size(0)
        top_bottom_ssim = torch.zeros(batch_size)
        left_right_ssim = torch.zeros(batch_size)

        for i in range(batch_size):
            # Reshape the outputs and labels to match SSIM function requirements
            output_img = outputs[i].view(9, 9).unsqueeze(0).unsqueeze(0)
            label_img = labels[i].view(9, 9).unsqueeze(0).unsqueeze(0)

            # Compute SSIM
            top_bottom_ssim[i] = 1 - torch.mean(torch.abs(output_img[0, :, 0, :] - label_img[0, :, -1, :]))
            left_right_ssim[i] = 1 - torch.mean(torch.abs(output_img[0, 0, :, :] - label_img[0, -1, :, :]))

        avg_tb_ssim = torch.mean(top_bottom_ssim)
        avg_lr_ssim = torch.mean(left_right_ssim)

        loss = avg_tb_ssim + avg_lr_ssim

        return loss


In [7]:
# Load and preprocess data
# from google.colab import drive
# drive.mount('/content/drive/')
# image_dir = "/content/drive/MyDrive/cavallo"
image_dir = "cavallo"
images = []
for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    image = cv2.imread(image_path)

    if image is not None:
        image = cv2.resize(image, (120, 120))
        images.append(image)
input_data = []
target_data = []


In [8]:
for image in images:
    parts = divide_image(image)
    combinations, original_positions = generate_combinations(parts, 10)

    for idx, combination in enumerate(combinations):

        # shape of combination is (9, 40, 40, 3)
        combination = np.array(combination).transpose(0, 3, 1, 2)
        combination = np.concatenate(combination, axis=0).transpose(1, 2, 0)
        input_data.append(combination)

        dummy_target = np.zeros((9, 9), dtype=np.uint8)
        for i in range(9):
            dummy_target[i, original_positions[idx][i]] = 1

        target_data.append(dummy_target.flatten())

input_data = np.array(input_data)
target_data = np.array(target_data)

In [9]:
# model = JigsawModel()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
model = JigsawModel().to(device)
model.load_state_dict(torch.load('bloss_model_38_50.pth'))
boundary_loss_fn = BoundaryLoss()
adversarial_loss_fn = nn.CrossEntropyLoss()
weight_adversarial = 0.5  # Adjust as needed
weight_boundary = 0.5
num_epochs = 50
batch_size = 64

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)



In [10]:
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(input_data, target_data, test_size=0.2, random_state=42)
print("shape of training data:", X_train.shape)



# Apply normalization transform to your dataset
train_dataset = JigsawDataset(X_train, y_train)
test_dataset = JigsawDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)



shape of training data: (20984, 40, 40, 27)


'cuda'

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)  # Move model to GPU if available

# adversarial_loss_fn = nn.CrossEntropyLoss()  # Use Cross Entropy Loss for multi-class classification

# for epoch in range(num_epochs):
#     running_loss = 0.0
#     total_correct = 0
#     total_samples = 0
#     for i, data in enumerate(train_loader, 0):
#         inputs, labels = data
#         inputs = inputs.permute(0, 3, 1, 2).float().to(device)  # Move inputs to GPU
#         labels = labels.float().to(device)  # Move labels to GPU
#         # print(labels.shape)
#         optimizer.zero_grad()

#         outputs = model(inputs)
#         # print(outputs.shape)
#         # Calculate adversarial loss
#         adversarial_loss = adversarial_loss_fn(outputs, labels)

#         b_loss = 0 #boundary_loss_fn(outputs, labels)
#         total_loss = weight_adversarial * adversarial_loss + weight_boundary * b_loss
#         # Apply gradient clipping
#         nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

#         total_loss.backward()
#         optimizer.step()

#         running_loss += total_loss.item()

#         if i % 100 == 99:
#             print('Epoch: %d, Batch: %5d, Loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
#             running_loss = 0.0

#         # Calculate accuracy
#         _, predicted_indices = torch.max(outputs, 1)  # Get the index with the highest probability
#         # print(predicted_indices.shape)
#         # print(labels.shape)
#         correct = (predicted_indices == labels.argmax(dim=1)).sum().item()
#         total_correct += correct
#         total_samples += labels.size(0)

#     accuracy = total_correct / total_samples
#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss}, Accuracy: {accuracy}')


Epoch: 1, Batch:   100, Loss: 19.701
Epoch: 1, Batch:   200, Loss: 19.261
Epoch: 1, Batch:   300, Loss: 18.746
Epoch [1/50], Loss: 500.23755645751953, Accuracy: 0.02441704892966361
Epoch: 2, Batch:   100, Loss: 18.264
Epoch: 2, Batch:   200, Loss: 18.015
Epoch: 2, Batch:   300, Loss: 17.767
Epoch [2/50], Loss: 479.5336856842041, Accuracy: 0.04544151376146789
Epoch: 3, Batch:   100, Loss: 17.452
Epoch: 3, Batch:   200, Loss: 17.409
Epoch: 3, Batch:   300, Loss: 17.305
Epoch [3/50], Loss: 468.5699882507324, Accuracy: 0.05948967889908257
Epoch: 4, Batch:   100, Loss: 17.015
Epoch: 4, Batch:   200, Loss: 17.023
Epoch: 4, Batch:   300, Loss: 16.962
Epoch [4/50], Loss: 455.6283531188965, Accuracy: 0.06474579510703364
Epoch: 5, Batch:   100, Loss: 16.681
Epoch: 5, Batch:   200, Loss: 16.672
Epoch: 5, Batch:   300, Loss: 16.617
Epoch [5/50], Loss: 449.9295406341553, Accuracy: 0.06880733944954129
Epoch: 6, Batch:   100, Loss: 16.347
Epoch: 6, Batch:   200, Loss: 16.353
Epoch: 6, Batch:   300, L

KeyboardInterrupt: 

In [11]:

model.eval()

JigsawModel(
  (conv1): Conv2d(27, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=12800, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=512, bias=True)
  (fc4): Linear(in_features=512, out_features=81, bias=True)
)

In [12]:
def recur(sequence, outputs):
    sequence = np.array(sequence)
    for i in range(len(sequence)):
        while np.sum(sequence == sequence[i]) > 1:  # If the current element is a duplicate
            # Find the index in outputs for the current sequence element that is not yet in the updated sequence
            scores = outputs[i]
            sorted_indices = np.argsort(scores)[::-1]  # Indices of scores sorted in descending order
            for idx in sorted_indices:
                if idx not in sequence:
                    sequence[i] = idx
                    break

    return sequence.tolist()

In [13]:
correct = 0
total = 0
per_tile_accuracy = 0
################
model.to("cpu")
model.eval()
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.permute(0, 3, 1, 2).float()
        outputs = model(inputs)

        # reshape the output to 9x9 matrix
        outputs = outputs.reshape(-1, 9, 9)
        predicted = torch.argmax(outputs, dim=1)

        # now doing the same for the target
        labels = labels.reshape(-1, 9, 9)
        target = torch.argmax(labels, dim=1)

        # check if the predicted sequence is correct
        for i in range(len(predicted)):
            updated_predicted = recur(predicted[i], outputs[i, : , :].numpy())

            if torch.equal(torch.tensor(updated_predicted), target[i]):
                correct += 1
            total += 1

            per_tile_accuracy += (np.array(updated_predicted) == target[i].numpy()).sum() / 9

print('Accuracy on test images: %d %%' % (100 * correct / total))
print('Per tile accuracy on test images: %d %%' % (100 * per_tile_accuracy / total))

Accuracy on test images: 3 %
Per tile accuracy on test images: 40 %


In [None]:
# import os

# # Create the directory if it doesn't exist
# os.makedirs('saved_models', exist_ok=True)

# # Save the model
# torch.save(model.state_dict(), 'saved_models/bloss_model_38.pth')


In [None]:
# from google.colab import drive

# # Mount Google Drive
# drive.mount('/content/drive')

# # Specify the directory where you want to save the model
# save_dir = '/content/drive/My Drive/saved_models/'

# # Create the directory if it doesn't exist
# os.makedirs(save_dir, exist_ok=True)

# # Save the model
# torch.save(model.state_dict(), os.path.join(save_dir, 'bloss_model_38_50.pth'))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
idx = random.randint(0, len(X_test))
image = X_test[idx]
target = y_test[idx]
inputs = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
outputs = model(inputs)
outputs = outputs.reshape(9, 9)
print(outputs)
converted_target = np.argmax(target.reshape(9, 9), axis=1)
print(converted_target)

tensor([[-2.5402e+00, -1.6388e+00, -2.3079e+00, -3.0343e+00, -2.2356e+00,
          4.6958e-01, -4.1293e+00, -1.4466e+00,  5.9697e-01],
        [-1.6023e+00, -2.6041e+00, -4.5954e-01, -1.2689e+00, -2.2203e+00,
         -3.1179e-01,  2.7006e-01, -2.7077e+00,  1.4654e+00],
        [-3.2233e+00, -6.1245e+00, -3.7110e+00, -3.7112e-01,  2.7203e+00,
         -1.5710e+00, -1.5550e+00, -2.3751e+00, -9.8417e-01],
        [-1.4831e+00,  1.2965e+00,  1.8673e+00, -6.8744e+00, -6.3272e+00,
         -5.9607e+00, -8.7382e+00, -1.4744e+01, -1.2202e+01],
        [ 7.6922e-01,  1.8393e+00, -1.2242e+00, -2.6569e+00, -9.7774e+00,
         -7.0908e+00, -6.5684e+00, -1.2753e+01, -9.3459e+00],
        [-7.6305e-01, -4.0106e+00, -4.5658e+00,  9.3758e-01, -5.0661e-01,
         -3.9638e+00, -3.9317e+00, -2.3472e+00, -5.6322e+00],
        [-5.1704e+00, -7.8025e+00, -3.2860e+00, -4.0859e+00, -2.8290e+00,
          8.3692e-03, -2.5862e+00,  3.3608e-01, -3.5263e-01],
        [-7.1278e+00, -6.4387e+00, -6.3696e+00, 

In [15]:
def recur(sequence, outputs):
    sequence = np.array(sequence)
    for i in range(len(sequence)):
        while np.sum(sequence == sequence[i]) > 1:  # If the current element is a duplicate
            # Find the index in outputs for the current sequence element that is not yet in the updated sequence
            scores = outputs[i]
            sorted_indices = np.argsort(scores)[::-1]  # Indices of scores sorted in descending order
            for idx in sorted_indices:
                if idx not in sequence:
                    sequence[i] = idx
                    break
                
    return sequence.tolist()

In [16]:
def stitch_shuffled_image(parts):
    num_parts = len(parts)
    part_size = parts[0].shape[0]  # Assuming all parts are square
    
    stitched_image_size = int(np.sqrt(num_parts) * part_size)
    stitched_image = np.zeros((stitched_image_size, stitched_image_size, parts[0].shape[2]), dtype=np.uint8)
    
    for i in range(stitched_image.shape[0] // part_size):
        for j in range(stitched_image.shape[1] // part_size):
            part_index = i * int(stitched_image.shape[0] / part_size) + j
            stitched_image[i*part_size:(i+1)*part_size, j*part_size:(j+1)*part_size] = parts[part_index]
    
    return stitched_image

In [32]:
import cv2
import random
import torch
import numpy as np

# Assuming reconstructed_image and recur functions are defined elsewhere

# Select random image from test set
idx = random.randint(0, len(X_test))
image = X_test[idx]
target = y_test[idx]

# Plot only the reconstructed image with predicted sequence
inputs = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
outputs = model(inputs)
outputs = outputs.reshape(9, 9)

predicted = torch.argmax(outputs, dim=1)
updated_predicted = recur(predicted, outputs.detach().numpy())

reconstructed_img = reconstructed_image(image, updated_predicted)

# Convert the image to BGR format as cv2 uses BGR by default
reconstructed_img_bgr = cv2.cvtColor(reconstructed_img, cv2.COLOR_RGB2BGR)

# Save the reconstructed image with the predicted sequence
cv2.imwrite('reconstructed_image_with_predicted_sequence_5.png', reconstructed_img_bgr)


True