In [1]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'animals10:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F59760%2F840806%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240510%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240510T190515Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D7a2c571ee97a446b1f6318950094b35493748a15ab014c58c63c5285198e8310c153a90b3fea16dfac54cfc34c2c953174afbe5dc4fbabf85a6fc44b1ee9777082efb0003e2d6b723fb753af90fbb55744d7695c654f91b2f4fb66be956d308ff73e011e063de9ed6a32c3df5a7c952f4680ba1a3f0fadcbd8eee4d421752cbe64c583466c9a7b5df98f82d7059fee78968f33ed681b8c7191fdb41ff91f7d8fc9f2eefa6d2ccad66195929c6de7bc67e3c8d06b50b933bd8784dc30a7b3c3db1b49cd2f738c1f8caac53e1783a21da004ad1cb71574e228f19297eddb7b46dc66196505cff0eef49137424eb7f2def95337d5f358ee6206841cdb0cf2b7b9aa'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    directory, download_url_encoded = data_source_mapping.split(':')
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            total_length = fileres.headers['content-length']
            print(f'Downloading {directory}, {total_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                done = int(50 * dl / int(total_length))
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            if filename.endswith('.zip'):
              with ZipFile(tfile) as zfile:
                zfile.extractall(destination_path)
            else:
              with tarfile.open(tfile.name) as tarfile:
                tarfile.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError as e:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        continue
    except OSError as e:
        print(f'Failed to load {download_url} to path {destination_path}')
        continue

print('Data source import complete.')


Downloading animals10, 614087302 bytes compressed
Downloaded and uncompressed: animals10
Data source import complete.


In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
/kaggle/input/animals10/raw-img/ragno/OIP-1G-_JV28G6976mfXjykQcgHaHa.jpeg
/kaggle/input/animals10/raw-img/ragno/eb32b20e2cf0093ed1584d05fb1d4e9fe777ead218ac104497f5c97ca5ecb3b9_640.jpg
/kaggle/input/animals10/raw-img/ragno/OIP-I_ba5QeKVDW2RFKgPMPMQQHaL0.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-5vzzgyVyqHlx11Y2wGoeLwAAAA.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-qtAq3et2EPPSkXpntkf8cQHaFj.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-UWiebwKp2UeZSdlAsXW17gHaE8.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-hlFm3GH5UR57BJgqfaO-jQHaFL.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-5KDvxPqrokUhTtH06pik2gHaFs.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-4qLKRE0sl0HXpnMvthliGQHaGm.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-Q2uC4pz8shAh7njh3n6NqAAAAA.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-9DH7GCXH_VephQ5pUHJBgQHaE7.jpeg
/kaggle/input/animals10/raw-img/ragno/OIP-GZ_7XYpqqkrS1dS2JvZIKwHaFj.jpeg
/kaggle/in

In [3]:
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [52]:
image_dir = "/kaggle/input/animals10/raw-img/cavallo"

images = []
for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    image = cv2.imread(image_path)

    if image is not None:
        image = cv2.resize(image, (120, 120))
        images.append(image)

print("Number of images read:", len(images))

Number of images read: 2623


In [53]:
def divide_image(image):
    parts = []
    height, width, _ = image.shape
    part_height = height // 3
    part_width = width // 3

    for i in range(3):
        for j in range(3):
            part = image[i*part_height:(i+1)*part_height, j*part_width:(j+1)*part_width]
            parts.append(part)

    return parts

In [54]:
def generate_combinations(parts, num_combinations):
    combinations = []
    original_positions = []
    indices = list(range(len(parts)))

    for _ in range(num_combinations):
        random.shuffle(indices)
        combination = [parts[i] for i in indices]
        combinations.append(combination)
        original_positions.append(indices.copy())

    return combinations, original_positions

In [55]:
def stitch_shuffled_image(parts):
    num_parts = len(parts)
    part_size = parts[0].shape[0]  # Assuming all parts are square

    stitched_image_size = int(np.sqrt(num_parts) * part_size)
    stitched_image = np.zeros((stitched_image_size, stitched_image_size, parts[0].shape[2]), dtype=np.uint8)

    for i in range(stitched_image.shape[0] // part_size):
        for j in range(stitched_image.shape[1] // part_size):
            part_index = i * int(stitched_image.shape[0] / part_size) + j
            stitched_image[i*part_size:(i+1)*part_size, j*part_size:(j+1)*part_size] = parts[part_index]

    return stitched_image

In [56]:

input_data = []
original_images = []
target_data = []

for image in images:
    parts = divide_image(image)
    combinations, original_positions = generate_combinations(parts, 10)

    for idx, combination in enumerate(combinations):
        shuffled_image = stitch_shuffled_image(combination)
        input_data.append(shuffled_image)
        original_images.append(image)  # Append the original image for reconstruction comparison

        dummy_target = np.zeros((9, 9), dtype=np.uint8)
        for i in range(9):
            dummy_target[i, original_positions[idx][i]] = 1

        target_data.append(dummy_target.flatten())
# Assuming dummy_target is 9x9 and represents a grid where only one element is '1' in each row.
converted_target_data = [np.argmax(dummy_target.reshape(-1, 9), axis=1) for dummy_target in target_data]
target_data = np.array([np.argmax(row) for row in converted_target_data])  # Convert to single indices


original_images = np.array(original_images)

In [57]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim

# class JigsawVAE(nn.Module):
#     def __init__(self):
#         super(JigsawVAE, self).__init__()
#         # Encoder
#         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
#         self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
#         self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)

#         # Fully connected layers for classification
#         self.fc1 = nn.Linear(64 * 32 * 32, 512)
#         self.fc2 = nn.Linear(512, 9)  # Assuming 9 classes for the 3x3 puzzle

#         # Decoder
#         self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.deconv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.deconv3 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

#     def forward(self, x):
#         # Encode
#         x = F.relu(self.conv1(x))
#         x = F.relu(self.conv2(x))
#         encoded = F.relu(self.conv3(x))

#         # Classification
#         flat = torch.flatten(encoded, start_dim=1)
#         fc = F.relu(self.fc1(flat))
#         classes = self.fc2(fc)

#         # Decode
#         x = F.relu(self.deconv1(encoded))
#         x = F.relu(self.deconv2(x))
#         reconstructed = torch.sigmoid(self.deconv3(x))  # Use sigmoid if inputs are normalized to [0,1]

#         return classes, reconstructed

# model = JigsawVAE()
class JigsawVAE(nn.Module):
    def __init__(self):
        super(JigsawVAE, self).__init__()
        # Encoder
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)  # Output size: (16, 128, 128) for 256x256 input
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1) # Output size: (32, 64, 64)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # Output size: (64, 32, 32)

        # Adjust the number of flat features according to the actual output size
        self.num_flat_features = 64 * 15 * 15  # Corrected to match the output from conv3: 64 channels, each 15x15

        # Fully connected layers for classification
        self.fc1 = nn.Linear(self.num_flat_features, 512)
        self.fc2 = nn.Linear(512, 9)  # Assuming 9 classes for the 3x3 puzzle

        # Decoder
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        # Encode
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten for the fully connected layers
        x = x.view(-1, self.num_flat_features)

        # Classification
        fc = F.relu(self.fc1(x))
        classes = self.fc2(fc)

        # Decode
        x = x.view(-1, 64, 15, 15)  # Reshape x to the expected shape for deconvolution
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        reconstructed = torch.sigmoid(self.deconv3(x))  # Use sigmoid if inputs are normalized to [0,1]

        return classes, reconstructed


In [58]:
# class JigsawDataset(Dataset):
#     def __init__(self, input_data, target_data, original_images, transform=None):
#         self.input_data = input_data
#         self.target_data = target_data  # This is still one-hot encoded here
#         self.original_images = original_images
#         self.transform = transform

#     def __len__(self):
#         return len(self.input_data)

#     def __getitem__(self, idx):
#         shuffled_image = self.input_data[idx]
#         original_image = self.original_images[idx]

#         if isinstance(shuffled_image, np.ndarray):
#             shuffled_image = transforms.ToPILImage()(shuffled_image)
#             original_image = transforms.ToPILImage()(original_image)

#         if self.transform:
#             shuffled_image = self.transform(shuffled_image)
#             original_image = self.transform(original_image)

#         # Convert one-hot encoding to index
#         one_hot_vector = self.target_data[idx]
#         class_index = np.argmax(one_hot_vector)  # Convert from one-hot to index
#         class_indices = torch.tensor(class_index, dtype=torch.long)  # Ensure it's a single value

#         return shuffled_image, original_image, class_indices

class JigsawDataset(Dataset):
    def __init__(self, input_data, target_data, original_images, transform=None):
        """
        input_data: List of shuffled images (pieces assembled in wrong order)
        target_data: List of targets indicating the correct position of each piece.
                     This should be a list of indices, not one-hot encoded.
        original_images: List of original unshuffled images for reconstruction
        transform: Optional transform to be applied on a sample.
        """
        self.input_data = input_data
        self.target_data = target_data
        self.original_images = original_images
        self.transform = transform

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
      shuffled_image = self.input_data[idx]
      original_image = self.original_images[idx]

      if isinstance(shuffled_image, np.ndarray):  # Ensure image is in correct format
          shuffled_image = transforms.ToPILImage()(shuffled_image)
          original_image = transforms.ToPILImage()(original_image)

      if self.transform:
          shuffled_image = self.transform(shuffled_image)
          original_image = self.transform(original_image)

      # Ensure class_indices is a 1D tensor of type long for cross_entropy
      class_indices = torch.tensor(self.target_data[idx], dtype=torch.long)

      return shuffled_image, original_image, class_indices




# Define transformations if necessary
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])


In [59]:
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

# Assuming JigsawDataset is defined to take input_data, original_images, and target_data

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Ensures input is converted to a tensor and scaled to [0, 1]
    # Add any necessary resizing or normalization here if required
])

# Split data into training and testing sets, including the original_images
X_train, X_test, y_train, y_test, original_images_train, original_images_test = train_test_split(
    input_data, target_data, original_images, test_size=0.2, random_state=42)

# Create instances of the JigsawDataset including original images for VAE reconstruction
train_dataset = JigsawDataset(X_train, y_train, original_images_train, transform=transform)
test_dataset = JigsawDataset(X_test, y_test, original_images_test, transform=transform)
batch_size = 16
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [66]:
# Ensure your loss function expects the correct parameters:
# model = JigsawVAE()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JigsawVAE().to(device)
def loss_function(class_pred, target_classes, recon_x, original_x):
    class_indices = torch.argmax(class_pred, dim=1)

    # Convert class indices to one-hot encoding
      # total number of classes
    one_hot_class = F.one_hot(class_indices, num_classes=81)

    classification_loss = F.cross_entropy(one_hot_class.float(), target_classes)
    reconstruction_loss = F.mse_loss(recon_x, original_x)
    return classification_loss + reconstruction_loss

# Example training loop call:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    for batch_idx, (data, original, classes) in enumerate(train_loader):
      data, original, classes = data.to(device), original.to(device), classes.to(device)
      optimizer.zero_grad()
      class_pred, recon_batch = model(data)
      loss = loss_function(class_pred, classes, recon_batch, original)
      loss.backward()
      #print('losss',loss)
      optimizer.step()
      #print('hi')

      if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

      print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000
====> Epoch: 6 Average loss: 0.0000

In [68]:
    def calculate_accuracy(y_pred, y_true):
      _, predicted = torch.max(y_pred, 1)
      correct = (predicted == y_true).sum().item()
      return correct / y_true.size(0)
    model.eval()
    valid_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, original, classes in test_loader:
            data, original, classes = data.to(device), original.to(device), classes.to(device)
            class_pred, recon_batch = model(data)
            valid_loss += loss_function(class_pred, classes, recon_batch, original).item()
            correct += calculate_accuracy(class_pred, classes) * data.size(0)
            total += data.size(0)

    valid_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / total
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        valid_loss, correct, total, accuracy))


Validation set: Average loss: 0.2720, Accuracy: 599.0/5232 (11%)



In [69]:
torch.save(model.state_dict(), 'model_vae.pth')
print('Saved model to "model_vae.pth"')

Saved model to "model_vae.pth"
