In [7]:
# Importing Required Libraries and Modules
import os
import time
import copy
import random
import pickle
import importlib as lib
from collections import defaultdict

# Data Manipulation and Configuration
import numpy as np
import pandas as pd
import yaml

# PyTorch Utilities
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset

# Image Processing Libraries
import cv2
import matplotlib.pyplot as plt
from matplotlib import rc
import torchvision.transforms as transforms
from torchvision.transforms import Resize, ToTensor, Compose, Lambda

# Utility Functions
from utils import mnist_dataset
from utils import dataset_getter as dat

# Metrics and Evaluation
from tqdm import tqdm

# Logic Layer Implementation
from difflogic import LogicLayer, GroupSum, PackBitsTensor
import difflogic_cuda

# Configuration Management with Hydra
from hydra import initialize, compose

In [8]:
# Report how many CUDA devices PyTorch can see in this kernel.
gpu_count = torch.cuda.device_count()
print(f"Detected GPUs: {gpu_count}")

Detected GPUs: 4


In [12]:
# --- Dataset / preprocessing configuration ---
# `crop`: None applies no cropping; pass an (x, x) tuple to set the crop
# dimensions (how it is interpreted is up to dat.get_dataset).
crop = None

# Dataset to load: "mnist" or "fashion_mnist".
dataset_name = "mnist"

# Bits per pixel; ExpLogic consumes binary inputs, so this stays fixed at 1.
bpp = 1

# --- Reproducibility ---
# Seed every RNG used in this notebook (PyTorch CPU + CUDA, NumPy, Python's
# `random`), in that order, with the same value.
seed_value = 42
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    seed_fn(seed_value)

# Build the train/test loaders via the project helper; it also returns two
# dimension values (input_dim, out_dim).
batch_size = 512
train_loader, test_loader, input_dim, out_dim = dat.get_dataset(
    dataset_name,
    batch_size=batch_size,
    data_dir="./data",
    bpp=bpp,
    crop=crop,
)

# Per-class sample counts for labels 0-9 in each split. The rarest class in a
# split sets how many samples per class are kept when that split is balanced.
train_targets = train_loader.dataset.targets
test_targets = test_loader.dataset.targets

train_class_counts = [int((train_targets == label).sum()) for label in range(10)]
test_class_counts = [int((test_targets == label).sum()) for label in range(10)]

min_samples_train, min_samples_test = min(train_class_counts), min(test_class_counts)

# Define Function to Trim Datasets for Balance
def balance_dataset(dataset, targets, min_samples, num_classes=10):
    """Return a shuffled Subset of `dataset` with at most `min_samples` items per class.

    For each label in ``range(num_classes)``, the first ``min_samples`` positions
    whose target equals that label are kept. The combined indices are then
    shuffled with ``torch.randperm``, so the order depends on torch's global RNG.

    Args:
        dataset: Dataset to subset (indexed by position).
        targets: Per-sample labels aligned with `dataset`; a tensor or any
            sequence accepted by ``torch.as_tensor``.
        min_samples: Maximum number of samples kept per class.
        num_classes: Labels ``0 .. num_classes - 1`` are considered. The default
            of 10 matches the previously hard-coded value.

    Returns:
        torch.utils.data.Subset over the selected, shuffled indices (plain ints).
    """
    targets = torch.as_tensor(targets)
    per_class = [
        (targets == class_label).nonzero(as_tuple=True)[0][:min_samples]
        for class_label in range(num_classes)
    ]
    # torch.cat keeps a 1-D int64 index tensor, even when every class is empty.
    # The old list-of-0-d-tensors -> torch.tensor path became float in that case.
    indices = torch.cat(per_class) if per_class else torch.empty(0, dtype=torch.long)

    # Shuffle the indices so classes are interleaved (same randperm call as before).
    shuffled_indices = indices[torch.randperm(indices.numel())]

    # Plain ints mean the wrapped dataset never receives 0-d tensors as indices.
    return Subset(dataset, shuffled_indices.tolist())

# Apply balancing to the train and test datasets.
balanced_train_dataset = balance_dataset(train_loader.dataset, train_targets, min_samples_train)
balanced_test_dataset = balance_dataset(test_loader.dataset, test_targets, min_samples_test)

# Re-initialize DataLoaders with the balanced datasets.
# Training drops the trailing partial batch so every step sees `batch_size` samples.
train_loader_balanced = DataLoader(
    balanced_train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True
)
# Evaluation must see every test sample. drop_last=True would silently exclude the
# final partial batch (up to batch_size - 1 samples) from all test-set metrics.
# NOTE(review): if downstream evaluation needs a fixed batch size, handle the
# remainder there instead of discarding test data.
test_loader_balanced = DataLoader(
    balanced_test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=False
)

In [None]:

# Report the size of each balanced split.
print(f"Balanced Train Dataset Size: {len(train_loader_balanced.dataset)}")
print(f"Balanced Test Dataset Size: {len(test_loader_balanced.dataset)}")

# From here on, `train_loader` / `test_loader` refer to the balanced loaders.
train_loader, test_loader = train_loader_balanced, test_loader_balanced

# Visualize a random image from the balanced training dataset.
data_index = random.randint(0, len(train_loader.dataset) - 1)
random_image, sample_label = train_loader.dataset[data_index]

# Each pixel is stored as `bpp` consecutive features, most-significant bit first.
# Fold them back into one intensity per pixel in [0, 2**bpp - 1].
# The old weight 2**(bpp - j + 1) was off by a factor of 4; MSB-first is 2**(bpp - j - 1).
# NOTE(review): assumes the sample is a flat vector of n_pixels * bpp values for a
# square image, as the original per-pixel indexing implied. Confirm against dat.get_dataset.
flat_features = np.asarray(random_image, dtype=np.float64).reshape(-1)
n_pixels = flat_features.size // bpp
side = int(round(np.sqrt(n_pixels)))
assert side * side == n_pixels, f"expected a square image, got {n_pixels} pixels"

# Derive the shape from the data instead of hard-coding 20/28, so any crop size works.
image_shape = (side, side)
bit_weights = 2.0 ** np.arange(bpp - 1, -1, -1)  # 2**(bpp-1), ..., 2, 1
processed_image = (
    flat_features[: n_pixels * bpp].reshape(n_pixels, bpp) @ bit_weights
).reshape(image_shape)

# Plot the processed image on a fixed intensity scale, so it is not auto-stretched.
fig, ax = plt.subplots()
ax.imshow(processed_image, cmap="gray", vmin=0, vmax=2**bpp - 1)
ax.set_title("Sample Image from Balanced Dataset")
ax.set_xlabel(f"label: {sample_label}")
plt.show()