In [None]:
%run ../supportvectors-common.ipynb

In [2]:
import torch
from torchvision.transforms import v2
from svlearn.config.configuration import ConfigurationMixin
                                                        
from svlearn.auto_encoders.resnet_auto_encoder import ResNetAutoEncoder

In [3]:
import gc

# Collect unreachable Python objects first so that any tensors they still
# referenced are handed back to PyTorch's CUDA caching allocator, then ask the
# allocator to release its unused cached blocks. When CUDA is unavailable there
# is no cache to release.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Project-wide settings come from svlearn's ConfigurationMixin. The
# 'tree-classification' section records the folder that holds training results
# (the autoencoder checkpoint is read from there further down).
config = ConfigurationMixin().load_config()
tree_classification_cfg = config['tree-classification']
results_dir = tree_classification_cfg['results']

In [5]:
# Evaluation-time preprocessing applied to every image before it is fed to the model.
test_transform = v2.Compose(
    [
        # PIL image / ndarray -> torchvision Image tensor (channels first)
        v2.ToImage(),
        # force a fixed 224x224 input size for the CNN (aspect ratio is not preserved)
        v2.Resize(size=(224, 224)),
        # convert to float32, rescaling integer pixel values into [0, 1]
        v2.ToDtype(torch.float32, scale=True),
        # per-channel standardisation using the commonly used ImageNet mean/std
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [6]:
import numpy as np
def generate_mask(images, patch_size=16, mask_ratio=0.2, rng=None):
    """Zero out a random subset of square patches in every image of a batch.

    Each image is split into a grid of non-overlapping ``patch_size`` x
    ``patch_size`` patches (pixels beyond the last full patch row/column are
    never masked). ``int(mask_ratio * num_patches)`` patches are drawn without
    replacement and set to 0. The zeros are applied to the tensor exactly as
    given: for inputs standardised with a mean/std (e.g. by ``v2.Normalize``),
    0 corresponds to the per-channel mean rather than to black.

    Args:
        images: 4-D tensor indexed as (batch, channels, height, width).
        patch_size: Side length, in pixels, of one square patch.
        mask_ratio: Fraction of patches to mask; must lie in [0, 1].
        rng: Optional ``numpy.random.Generator``. When omitted, a fresh
            generator seeded with 42 is created, so every call yields the same
            masks (the previous behaviour). Pass one shared generator to get
            different masks on successive calls, e.g. when masking images one
            at a time.

    Returns:
        A new tensor with the same shape, dtype and device as ``images`` in
        which the selected patches are zero.

    Raises:
        ValueError: If ``mask_ratio`` is outside [0, 1].
    """
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
    if rng is None:
        # Re-seeding per call keeps the default reproducible, but it also
        # means every call draws the identical patch pattern.
        rng = np.random.default_rng(42)

    batch_size, _, height, width = images.shape
    patches_per_row = width // patch_size
    num_patches = (height // patch_size) * patches_per_row
    num_masked = int(mask_ratio * num_patches)

    mask = torch.ones_like(images)
    for i in range(batch_size):
        # One independent draw per image; flat patch index -> (row, col) offset.
        for idx in rng.choice(num_patches, num_masked, replace=False):
            row = (idx // patches_per_row) * patch_size
            col = (idx % patches_per_row) * patch_size
            mask[i, :, row:row + patch_size, col:col + patch_size] = 0

    return images * mask

In [14]:
# Load every image in the sample folder, apply the evaluation transform and
# optionally blank out random patches (MASK_RATIO) before reconstruction.
import os
from PIL import Image, UnidentifiedImageError

# Machine-specific default; set TREE_IMAGES_DIR to point elsewhere instead of editing this.
image_path = os.environ.get('TREE_IMAGES_DIR', '/home/chandar/images')
# Fraction of patches zeroed per image by generate_mask; 0.0 feeds the clean images.
MASK_RATIO = 0.0

images = []
for filename in sorted(os.listdir(image_path)):  # sorted -> reproducible display order
    file_path = os.path.join(image_path, filename)
    if not os.path.isfile(file_path):  # skip sub-folders such as .ipynb_checkpoints
        continue
    try:
        with Image.open(file_path) as raw:
            img = raw.convert('RGB')  # ensure 3-channel RGB; convert() returns a loaded copy
    except UnidentifiedImageError:
        print(f"Skipping non-image file: {file_path}")
        continue
    # Apply transformations
    img_transformed = test_transform(img)
    masked_image = generate_mask(img_transformed.unsqueeze(0), mask_ratio=MASK_RATIO).squeeze(0)
    images.append(masked_image)

assert images, f"No readable images found in {image_path}"


In [8]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instantiate the autoencoder on that device; its trained weights are restored
# from the saved checkpoint in a later cell.
model = ResNetAutoEncoder().to(device)

In [9]:
# Load the trained masked-autoencoder checkpoint written during training.
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU can still be restored on the CPU fallback chosen above.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(
    f"{results_dir}/trees_resnet50_masked_autoencoder.pt",
    map_location=device,
)

In [None]:
# Restore the trained weights (load_state_dict defaults to strict key matching,
# so a mismatched checkpoint raises instead of loading silently).
model.load_state_dict(checkpoint['model_state_dict'])
# eval() puts modules such as dropout / batch-norm into inference behaviour.
# The trailing ';' suppresses the very long module repr as cell output.
model.eval();

In [None]:
import matplotlib.pyplot as plt
from svlearn.auto_encoders.auto_encoder_util import convert

# Reconstruct all loaded images in one batch, then show the inputs (top row)
# above the autoencoder's reconstructions (bottom row).
num_images = len(images)
assert num_images > 0, "No images to reconstruct -- check the image-loading cell."

with torch.no_grad():
    # Build a separate batch tensor instead of rebinding `images`, so this cell
    # can be re-run safely and the loaded list stays intact.
    image_batch = torch.stack(images).to(device)
    reconstructed, _ = model(image_batch)

# squeeze=False keeps `axs` 2-D even when only a single image is shown;
# without it, axs[0, i] raises IndexError for num_images == 1.
fig, axs = plt.subplots(2, num_images, figsize=(10, 4), squeeze=False)
for i in range(num_images):
    axs[0, i].imshow(convert(image_batch[i].cpu()))
    axs[1, i].imshow(convert(reconstructed[i].cpu()))
    axs[0, i].axis('off')
    axs[1, i].axis('off')
fig.suptitle("Top: input images | Bottom: autoencoder reconstructions")
plt.show()