In [1]:
import sys
# Make the directory two levels up importable (the `src` package imported
# further down is presumably located there). The membership check keeps the
# cell idempotent: re-running it no longer appends duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [2]:
import torch

# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# This has to be a single if/elif/else chain: with two independent `if`
# statements the `else` of the MPS check would overwrite an earlier 'cuda'
# selection with 'cpu' on machines that have CUDA but no MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [3]:
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from src.dataset.footprint2pressure import Footprint2Pressure

## Sensor stacks

In [31]:
class Footprint2Pressure_SensorStack(Footprint2Pressure):
    """Dataset variant that serves footprint images as per-sensor patch stacks.

    Each item is ``((img_stack, young), pedar_t)``:

    * ``img_stack`` -- left-foot and right-foot sensor patches, each resized to
      ``(img_size, img_size)`` and concatenated along the first axis.
    * ``young``     -- scalar tensor holding the Young's modulus of the sample's
      material (looked up in ``material_youngs``).
    * ``pedar_t``   -- pedar values for the (material, subject) pair divided by
      ``sense_range``.

    The attributes ``index``, ``pedar_dynamic``, ``sense_range``, ``x_grid``,
    ``y_grid``, ``img_size``, ``dtype``, ``device`` and
    ``footprint_wrap_folder`` are read here but not defined here; they are
    assumed to be provided by ``Footprint2Pressure.__init__`` -- TODO confirm.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base dataset, then add stack-specific state."""
        super().__init__(*args, **kwargs)

        # Young's modulus (MPa) per material
        self.material_youngs = {
            'Poron': 0.33,
            'PElite': 1.11,
            'Lunalight': 5.88,
            'Lunanastik': 0.71,
            # NOTE(review): presumably "barefoot" (no insole material) -- confirm
            'BF': 0.00,
        }

        # every sensor patch is rescaled to a square of side `img_size`
        self.resize = transforms.Resize((self.img_size, self.img_size))

    def _sensor_stack_for_foot(self, subject, foot: str):
        """Load one foot's footprint image and slice it into resized sensor patches.

        Args:
            subject: subject identifier used in the image file name.
            foot: key selecting the foot; ``__getitem__`` passes ``'L'`` and ``'R'``.

        Returns:
            Tensor of dtype ``self.dtype`` produced by indexing the image with
            ``x_grid[foot]`` / ``y_grid[foot]`` and applying ``self.resize``.
        """
        img_path = self.footprint_wrap_folder / f'{subject}-{foot}.jpg'

        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed as soon as the pixels have been copied into numpy.
        with Image.open(img_path) as img:
            # scale to [0, 1], invert, then average over the last axis
            # (presumably the colour channels -- assumes a multi-channel image)
            img_arr = np.mean(1 - np.array(img).astype(np.float64) / 255, axis=-1)

        # index with the per-foot grids to obtain one patch per sensor
        # (assumes the grids yield a (n_sensors, h, w) array -- TODO confirm)
        img_stack = img_arr[self.x_grid[foot], self.y_grid[foot]]
        img_stack = torch.tensor(img_stack, dtype=self.dtype)
        return self.resize(img_stack)

    def __getitem__(self, idx: int) -> tuple:
        """Return ``((img_stack, young), pedar_t)`` for sample ``idx``, on ``self.device``.

        Raises:
            KeyError: if the sample's material is missing from ``material_youngs``.
        """
        # material and its Young's modulus
        material = self.index[idx][0]
        young = torch.tensor(self.material_youngs[material], dtype=self.dtype)

        # subject
        subject = self.index[idx][1]

        # pedar values for this (material, subject), scaled by the sensing range
        arr_pedar = self.pedar_dynamic.loc[material, subject].values / self.sense_range
        pedar_t = torch.tensor(arr_pedar, dtype=self.dtype)

        # per-sensor stacks of both feet, left foot first
        l_stack = self._sensor_stack_for_foot(subject, 'L')
        r_stack = self._sensor_stack_for_foot(subject, 'R')
        img_stack = torch.concat([l_stack, r_stack])

        # remember to move data to device!
        return (img_stack.to(self.device), young.to(self.device)), pedar_t.to(self.device)

In [32]:
# Build the per-sensor dataset from the processed data files (paths are
# relative to this notebook) and display how many samples it contains.
# NOTE: the instance is bound to the name `self` because the next cell
# indexes it under that name.
self = Footprint2Pressure_SensorStack(
    device=device,
    footprint_wrap_folder='../../data/processed/footprint-wrap',
    pedar_dynamic_path='../../data/processed/pedar_dynamic.pkl',
    l_mask_path='../../data/processed/left_foot_mask.png',
)
len(self)

250

In [33]:
# Fetch sample 100 and unpack it as ((image stack, Young's modulus), pedar target).
(img_stack, young), pedar_t = self[100]
# Display the tensor shapes and the modulus value as a quick sanity check.
img_stack.shape, young, pedar_t.shape

(torch.Size([198, 10, 10]), tensor(5.8800, device='mps:0'), torch.Size([198]))

## Patch stacks