In [1]:
cd /home/CAMPUS/hdasari/Hypernetworks_stevens

/home/CAMPUS/hdasari/Hypernetworks_stevens


In [2]:
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import glob
from torchvision import transforms
from pathlib import Path

from src.utils.hyp_input import hyp_input
from src.utils.get_boundary_pixels import get_boundary_pixels
from src.utils.extract_patch import extract_patch


class Battery_unet_hyp_data(Dataset):
    def __init__(self, image_dir,unet_model, device, mask_function=hyp_input, get_boundaries=get_boundary_pixels, get_patch=extract_patch, transform=None):
        self.image_dir = image_dir
        # self.label_dir = label_dir
        self.mask_function = mask_function
        self.get_boundaries = get_boundaries
        self.get_patch = get_patch
        self.transform = transform
        self.unet_model = unet_model
        self.device = device
        
        self.image_files = sorted(Path(image_dir).glob('*.png'))
        
    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]

        # Load and transform images
        image = Image.open(image_path).convert('L')
        img_ndarray = np.asarray(image)
        img_ndarray = img_ndarray[np.newaxis, ...]  # Add channel dimension [1, H, W]
        image_tensor = torch.as_tensor(img_ndarray / 255.0).float().contiguous()
        
        with torch.no_grad():
            input_img = image_tensor.unsqueeze(0).to(self.device)  # Add batch dimension
            label_tensor = self.unet_model(input_img)
            label_tensor = torch.argmax(label_tensor, dim = 1)
            print()
        label_tensor = label_tensor.squeeze(0).cpu().type(torch.long)
        

        _, H, W = image_tensor.shape

        # Get key pixels and masked image
        key_pixels, masked_image = self.mask_function(label_tensor)

        all_patches = []
        all_labels = []

        mismatch = 0

        for x, y in key_pixels:
            boundary_pixels = self.get_boundaries(x, y)

            patches = []
            labels = []

            for bx, by in boundary_pixels:
                if bx < 0 or by < 0 or bx >= H or by >= W:
                    patches.append(torch.zeros((1, 9, 9), dtype=torch.long))
                    labels.append(255)
                    mismatch += 1
                else:
                    patches.append(self.get_patch(image_tensor, bx, by))
                    labels.append(label_tensor[bx, by])  # Get label ID

            all_patches.append(torch.stack(patches))  # Shape: (max_boundaries, C, H, W)
            all_labels.append(torch.tensor(labels, dtype=torch.long))

        # Convert lists to tensors
        all_patches = torch.stack(all_patches)  # Shape: (num_key_pixels, max_boundaries, C, H, W)
        all_labels = torch.stack(all_labels)  # Shape: (num_key_pixels, max_boundaries)

        return all_patches, masked_image, key_pixels, all_labels, mismatch


In [3]:
import torch
from torch.utils.data import DataLoader
from src.models.Unet_model import UNet



device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


unet_path = "/home/CAMPUS/sgangadh1/projects/rl-batt-seg-snapshot-jan-2024/src/outputs/rerun-battery-01/unet_model_checkpoint_finetuned.pt"

model_unet = UNet(n_channels=1, n_classes=3)
model_unet = model_unet.to(device)

checkpoint = torch.load(unet_path)

model_unet.load_state_dict(checkpoint['model_state_dict'])
model_unet.eval()

Using device: cuda:3


  checkpoint = torch.load(unet_path)


UNet(
  (softconv): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), s

In [4]:


image_dir = "/home/CAMPUS/sgangadh1/projects/rl-batt-seg-snapshot-jan-2024/data/battery_2/train_images"

# Load dataset
train_dataset = Battery_unet_hyp_data(image_dir, model_unet,device)


# Iterate through the DataLoader and inspect the output
def display_first_batch_shape(dataset, batch_size=1):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for batch in dataloader:
        all_patches, masked_image, key_pixels, all_labels, mismatch = batch
        print(f"all_patches length: {len(all_patches)}") 
        print(f"all_patches shape: {(all_patches.shape)}") # Expected: (batch_size, K, 40, C, H, W)
        print(f"masked_image shape: {masked_image.shape}")  # Expected: (batch_size, k,concatednated_patch[9*9+8])
        print(f"key_pixels shape: {key_pixels.shape}")  # Expected: (batch_size, K,2)
        print(f"all_labels shape: {all_labels.shape}")  # Expected: (batch_size, K, 40)
        print(f"Mismatches: {mismatch}")
        break
        
        

       
display_first_batch_shape(train_dataset, batch_size=16)


















all_patches length: 16
all_patches shape: torch.Size([16, 3, 40, 1, 9, 9])
masked_image shape: torch.Size([16, 3, 89])
key_pixels shape: torch.Size([16, 3, 2])
all_labels shape: torch.Size([16, 3, 40])
Mismatches: tensor([ 0,  0,  0,  0, 13,  0, 17,  0,  0,  0,  0,  0,  0,  0, 11, 19])
