In [None]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn import manifold
from torch.utils.data import DataLoader
import numpy as np
# Device for the autoencoder and every input batch (passed to .to(device) in later cells).
device='cpu'

In [None]:
# Define a Convolutional Autoencoder model
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder with a small fully-connected bottleneck.

    Designed for single-channel 28x28 images (e.g. EMNIST). The convolutional
    ``encoder`` reduces a (N, 1, 28, 28) batch to a (N, 128, 2, 2) feature map,
    which is flattened and projected by ``embedding_encoder`` to an
    ``embedding_dim``-dimensional code. ``embedding_decoder`` maps the code back
    to a 128x2x2 feature map that the transposed-convolution ``decoder``
    upsamples to a (N, 1, 28, 28) reconstruction squashed into [0, 1] by a
    final Sigmoid. The embedding is returned as-is (not normalised).

    Submodule names and layer order determine the ``state_dict`` keys, so they
    must stay unchanged for saved checkpoints to keep loading.

    Args:
        embedding_dim: size of the bottleneck embedding (default 8).
    """

    # (channels, height, width) of the encoder's output feature map.
    _BOTTLENECK_SHAPE = (128, 2, 2)

    def __init__(self, embedding_dim=8):
        super().__init__()
        channels, height, width = self._BOTTLENECK_SHAPE
        flat_features = channels * height * width  # 512

        # Spatial size: 28 -> 28 -> 14 -> 14 -> 7 -> 5 -> 2.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0), nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
        )
        # Spatial size: 2 -> 6 -> 6 -> 12 -> 12 -> 26 -> 28.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=0, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=0, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),
            nn.Sigmoid(),
        )
        # Flattened feature map -> embedding.
        self.embedding_encoder = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, embedding_dim),
        )
        # Embedding -> flattened feature map.
        self.embedding_decoder = nn.Sequential(
            nn.Linear(embedding_dim, 512), nn.ReLU(),
            nn.Linear(512, flat_features),
        )

    def forward(self, x):
        """Encode a batch and reconstruct it.

        Args:
            x: image batch, expected shape (N, 1, 28, 28).

        Returns:
            (xhat, embedding): the reconstruction, shape (N, 1, 28, 28) with
            values in [0, 1], and the bottleneck code, shape (N, embedding_dim).
        """
        batch = x.size(0)
        features = self.encoder(x).view(batch, -1)  # flatten to (N, 512)
        embedding = self.embedding_encoder(features)
        bottleneck = self.embedding_decoder(embedding).view(batch, *self._BOTTLENECK_SHAPE)
        xhat = self.decoder(bottleneck)
        return xhat, embedding

In [None]:
batch_size=128
transform = transforms.ToTensor()
# NOTE(review): this iterates the EMNIST *training* split (train=True) even though the
# loader is named `emnist_test_loader`; the name is kept because later cells use it.
emnist_data = torchvision.datasets.EMNIST(root='./data', train=True, split='byclass', download=True, transform=transform)
# shuffle=False keeps the embedding order aligned with the dataset order.
emnist_test_loader = DataLoader(emnist_data, batch_size=batch_size, shuffle=False)

# Load the trained autoencoder; embedding_dim must match the saved checkpoint.
autoencoder = ConvAutoencoder(embedding_dim=6).to(device)
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`,
# so a GPU-trained checkpoint still loads here.
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load('AE_EMNIST_0.pt', map_location=device)
autoencoder.load_state_dict(state_dict)
autoencoder.eval()

In [None]:
# Encode every image in the loader and save (embeddings, labels) for downstream use.
autoencoder.eval()
test_encode, test_targets = [], []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for x_val, y_val in emnist_test_loader:
        x_val = x_val.to(device)
        _, zhat = autoencoder(x_val)  # reconstructions are not needed here
        # .cpu() is required before .numpy() whenever `device` is not the CPU.
        test_encode.append(zhat.cpu().numpy())
        test_targets.append(y_val.numpy())  # labels never leave the CPU
X_list = np.vstack(test_encode)  # (num_samples, embedding_dim)
label_list = np.concatenate(test_targets)
EMNIST = (X_list, label_list)
# NOTE(review): absolute, machine-specific output path — consider making it configurable.
EMNIST_EMBEDDINGS_PATH = '/home/baly/projects/linear_pgw/pu_learning/data/EMNIST.pt'
torch.save(EMNIST, EMNIST_EMBEDDINGS_PATH)
print('Embeddings are calculated')

In [None]:

# Stack the per-batch NumPy arrays produced by the embedding cell. (torch.cat cannot be
# used: the lists hold NumPy arrays, not tensors.) New names keep this cell idempotent.
all_embeddings = np.vstack(test_encode)
all_targets = np.concatenate(test_targets)

# Select 10 distinct classes among the first 20 labels (replace with your chosen indices).
# A seeded Generator makes the choice reproducible; replace=False prevents the duplicates
# that np.random.randint could draw (which would silently yield fewer than 10 classes).
rng = np.random.default_rng(0)
selected_classes = rng.choice(20, size=10, replace=False)
mask = np.isin(all_targets, selected_classes)

# Filter the data
z_subset = all_embeddings[mask]
Y_subset = all_targets[mask]

# t-SNE becomes very slow on large inputs, and the selected classes may contain a very
# large number of images, so optionally subsample (consistently for points and labels).
# Set MAX_TSNE_POINTS = None to embed every selected point.
MAX_TSNE_POINTS = 20_000
if MAX_TSNE_POINTS is not None and len(z_subset) > MAX_TSNE_POINTS:
    keep = np.sort(rng.choice(len(z_subset), size=MAX_TSNE_POINTS, replace=False))
    z_subset, Y_subset = z_subset[keep], Y_subset[keep]

# Apply t-SNE to the subset
tsne = manifold.TSNE(n_components=2, init="pca", random_state=0)
X_2d_subset = tsne.fit_transform(z_subset)

In [None]:
import numpy as np

In [None]:
# Scatter the 2-D t-SNE embedding, one colour per selected class.
fig, ax = plt.subplots(figsize=(10, 10))
for class_index in np.unique(Y_subset):
    in_class = Y_subset == class_index
    ax.scatter(X_2d_subset[in_class, 0], X_2d_subset[in_class, 1], label=f'Class {class_index}', s=1)

ax.set_title('t-SNE of autoencoder embeddings (selected EMNIST classes)')
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
# Points are drawn with s=1, so scale up the legend markers to keep class colours readable.
ax.legend(markerscale=6)

plt.show()