In [10]:
from dataset import MnistDataset

In [13]:
from torch.utils.data import DataLoader

# Bag-level training set over all ten digits (UCC labels 1..4), wrapped in a
# loader that yields 32 bags per batch.
train_dataset = MnistDataset(
    mode="train",
    length=80000,
    digit_arr=list(range(10)),
    num_instances=2,
    num_samples_per_class=16,
    ucc_start=1,
    ucc_end=4,
)
train_loader = DataLoader(train_dataset, batch_size=32)

x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples


In [14]:
# Peek at the first batch only, to check the bag tensor layout.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    a, b = first_batch
    print(a.shape)

torch.Size([32, 2, 1, 28, 28])


In [16]:
import torch
import numpy as np

data_dir = "../data/mnist/splitted_mnist_dataset.npz"
# np.load on an .npz returns a lazy NpzFile that holds the file open until
# closed. Read every array eagerly inside a context manager so the handle is
# released. The resulting dict supports the same ["key"] access and dict() copy.
with np.load(data_dir) as npz_file:
    splitted_dataset = dict(npz_file)


In [17]:
def standardize_per_image(images: np.ndarray) -> torch.Tensor:
    """Convert raw uint8-range images to per-image standardized float tensors.

    Adds a channel axis at dim 1, scales by 1/255, then subtracts each image's
    own mean and divides by its own (unbiased) std over the two trailing axes.
    This matches the preprocessing previously inlined for each split.

    Args:
        images: array of images; presumably (N, 28, 28) for MNIST (TODO confirm).

    Returns:
        float32 tensor with a channel axis inserted at dim 1.
    """
    scaled = torch.tensor(images, dtype=torch.float32).unsqueeze(1) / 255
    mean = scaled.mean(dim=(2, 3), keepdim=True)
    std = scaled.std(dim=(2, 3), keepdim=True)
    # NOTE(review): an all-constant image would have std == 0 and produce NaN;
    # kept as-is so preprocessing stays identical to what the encoder saw.
    return (scaled - mean) / std


x_train = standardize_per_image(splitted_dataset["x_train"])
x_test = standardize_per_image(splitted_dataset["x_test"])
x_val = standardize_per_image(splitted_dataset["x_val"])

In [18]:
import torch
import os
from model import UCCModel
from omegaconf import OmegaConf

# Restore the trained UCC model from its Hydra run directory.
# Build paths from components rather than hard-coded "\\" separators so the
# cell works on POSIX as well as Windows (on Windows the strings are identical).
model_path = os.path.join("outputs", "2024-03-01", "14-44-08")
ucc_cfg = OmegaConf.load(os.path.join(model_path, ".hydra", "config.yaml"))
model = UCCModel(ucc_cfg)
# weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
# map_location="cpu" lets the checkpoint load on a machine without the GPU it
# was saved from; load_state_dict copies the tensors into the model's own
# parameters, so the final model state is unchanged.
checkpoint = torch.load(
    os.path.join(model_path, "mnist_ucc_best.pth"),
    map_location="cpu",
    weights_only=False,
)
state_dict = checkpoint["model_state_dict"]
model.load_state_dict(state_dict)
model.eval()

UCCModel(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): WideResidualBlock(
      (blocks): Sequential(
        (0): ResBlockZeroPadding(
          (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (skip_conv): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (2): WideResidualBlock(
      (blocks): Sequential(
        (0): ResBlockZeroPadding(
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (skip_conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
        )
      )
    )
    (3): WideResidualBlock(
      (blocks): Sequential(
        (0): ResBlockZeroPadding(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 

In [19]:
# Encode the training images with the trained encoder, in fixed-size batches.
ENCODE_BATCH_SIZE = 200
# Fall back to CPU so the cell still runs on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device=device)
x_train_chunks = []
with torch.no_grad():
    # Stepping by the batch size includes the final partial batch, which the
    # old int(len(x_train) / 200) batch count silently dropped.
    for start in range(0, len(x_train), ENCODE_BATCH_SIZE):
        batch = x_train[start:start + ENCODE_BATCH_SIZE].to(device=device, dtype=torch.float32)
        x_train_chunks.append(model.encoder(batch).cpu().numpy())
# Concatenate once at the end instead of re-copying a growing array per batch.
x_train_encoded = np.concatenate(x_train_chunks)

In [20]:

# Encode the validation and test splits exactly like the training split.
ENCODE_BATCH_SIZE = 200
encoded_splits = {}
with torch.no_grad():
    for split_name, images in (("val", x_val), ("test", x_test)):
        # range(0, n, step) covers a trailing partial batch, so no sample is
        # dropped when the split size is not a multiple of the batch size.
        chunks = [
            model.encoder(
                images[start:start + ENCODE_BATCH_SIZE].to(device=device, dtype=torch.float32)
            ).cpu().numpy()
            for start in range(0, len(images), ENCODE_BATCH_SIZE)
        ]
        # Single concatenation instead of a quadratic grow-in-loop.
        encoded_splits[split_name] = np.concatenate(chunks)
x_val_encoded = encoded_splits["val"]
x_test_encoded = encoded_splits["test"]

In [23]:
# Every image in each split must have been encoded (guards against dropped batches).
assert len(x_train_encoded) == len(x_train), "train split not fully encoded"
assert len(x_val_encoded) == len(x_val), "val split not fully encoded"
assert len(x_test_encoded) == len(x_test), "test split not fully encoded"
x_test_encoded.shape



(10000, 10)

In [9]:
# Mutable copy of every array in the raw split, keyed by array name.
a = {name: splitted_dataset[name] for name in splitted_dataset}

In [10]:
# Swap the raw images for their encoder features; other arrays stay as-is.
a.update(
    x_train=x_train_encoded,
    x_test=x_test_encoded,
    x_val=x_val_encoded,
)

In [11]:
# Persist the encoded split next to the raw one.
encoded_dataset_path = "../data/mnist/splitted_mnist_encoded_dataset.npz"
np.savez(encoded_dataset_path, **a)

In [12]:
# Reload the encoded split. A lazy NpzFile would keep the file open (and on
# Windows block re-running the np.savez cell on the same path), so read all
# arrays eagerly inside a context manager and let it close the handle.
with np.load("../data/mnist/splitted_mnist_encoded_dataset.npz") as encoded_npz:
    splitted_dataset_encoded = dict(encoded_npz)

In [2]:
from dataset import MnistEncodedDataset

# Same bag configuration as the raw-image training set, but built on the
# pre-computed encoder features.
dataset = MnistEncodedDataset(
    mode="train",
    length=80000,
    digit_arr=list(range(10)),
    num_instances=2,
    num_samples_per_class=16,
    ucc_start=1,
    ucc_end=4,
)

x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
x_train shape: torch.Size([50000, 10])
50000 train samples
10000 val samples
{'class_0': array([[5],
       [2],
       [6],
       [9],
       [3],
       [0],
       [8],
       [7],
       [1],
       [4]]), 'class_1': array([[1, 8],
       [1, 2],
       [6, 9],
       [0, 1],
       [3, 6],
       [0, 4],
       [2, 6],
       [0, 6],
       [4, 9],
       [0, 5],
       [3, 8],
       [8, 9],
       [3, 5],
       [2, 9],
       [4, 6],
       [7, 9],
       [6, 7],
       [0, 8],
       [2, 4],
       [2, 3],
       [5, 8],
       [2, 5],
       [1, 3],
       [4, 5],
       [1, 4],
       [1, 5],
       [0, 2],
       [1, 9],
       [5, 6],
       [0, 3],
       [5, 7],
       [0, 7],
       [3, 7],
       [3, 4],
       [0, 9],
       [3, 9],
       [4, 7],
       [2, 8],
       [1, 7],
       [5, 9],
       [4, 8],
       [6, 8],
       [2, 7],
       [7, 8],
       [1, 6]]), 'class_2': array(

In [3]:
from torch.utils.data import DataLoader
# Deterministic, single-process loader over the encoded bags.
dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False)

In [4]:
# Inspect one (bag features, label) sample.
sample = dataset[1]
sample

(tensor([[0.3472, 0.3366, 0.7501, 0.4872, 0.3195, 0.6923, 0.3661, 0.7954, 0.5620,
          0.6235],
         [0.6064, 0.2411, 0.7972, 0.3209, 0.2968, 0.5668, 0.4418, 0.4927, 0.3313,
          0.6819]]),
 1)

In [5]:
# Peek at the first encoded batch to check bag and label shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    a, b = first_batch
    print(a.shape)
    print(b.shape)

torch.Size([32, 2, 10])
torch.Size([32])
