In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from dataset.latent_dataset import LatentDataset
from dataset.latent_image_dataset import LatentImageDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Subset
import os
from tqdm import tqdm
import wandb
from datetime import datetime
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
import socket
from mpi4py import MPI
import pickle

In [5]:
def setup_wandb(config):
    """Start a Weights & Biases run for this experiment and return the ``wandb`` module.

    Args:
        config: Experiment configuration dict. ``config['task_name']`` is used as
            the prefix of the run name and the whole dict is stored in
            ``wandb.config``.

    Returns:
        The ``wandb`` module, so callers can use it as their logger handle.
    """
    wandb.init(
        project="Face-diffusion",
        entity="megleczmate",
        sync_tensorboard=True,
        tags=["latent_classifier"],
    )

    # The run name is the task name immediately followed by a start timestamp.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    active_run = wandb.run
    active_run.name = config['task_name'] + timestamp
    active_run.save()

    wandb.config.update(config)
    return wandb

In [6]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def train_model(model, criterion, optimizer, condition_config, world_size, rank, num_epochs=25, batch_size=64, wandb=None):
    """Train and validate ``model`` on the Big_Nose latent dataset.

    Every epoch runs one training pass followed by one validation pass. On
    rank 0 a wandb run is started via ``setup_wandb``, the training loss is
    logged on the first step and every 10th step, and validation loss,
    accuracy, precision, recall and F1 are logged per selected attribute.
    Rank 0 also writes a checkpoint of ``model.state_dict()`` to the current
    working directory after each training pass.

    Args:
        model: Module called as ``model(inputs)``. Its output is indexed as
            ``outputs[:, attr_idx]``, i.e. one logit column per selected attribute.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        condition_config: Experiment configuration. The list at
            ``['attribute_condition_config']['attribute_condition_selected_attrs']``
            determines which output columns are evaluated; the whole dict is
            passed to ``setup_wandb`` on rank 0.
        world_size: Number of replicas handed to ``DistributedSampler``.
        rank: Rank of this process. Also used as the device in ``.to(rank)``.
        num_epochs: Number of epochs to run.
        batch_size: Per-process batch size of both data loaders.
        wandb: Optional logger handle. On rank 0 it is replaced by the result
            of ``setup_wandb``; on other ranks nothing is logged because every
            log call is also gated on ``rank == 0``.

    Returns:
        The trained ``model``.

    Notes:
        Loss and validation metrics are computed over the samples yielded to
        this process by its sampler only; nothing is reduced across ranks.
    """
    # Path to the preprocessed latent dataset
    data_dir = '/mnt/g/data/latents/700/processed/big_nose'
    selected_attrs = condition_config['attribute_condition_config']['attribute_condition_selected_attrs']

    if rank == 0:
        wandb = setup_wandb(condition_config)

    # Load the train/val splits with their attribute labels
    image_datasets = {
        'train': LatentImageDataset(data_dir,
                                    split='train',
                                    target_attributes=['Big_Nose']),
        'val': LatentImageDataset(data_dir,
                                  split='val',
                                  target_attributes=['Big_Nose']),
    }

    samplers = {
        'train': DistributedSampler(image_datasets['train'], num_replicas=world_size, rank=rank, shuffle=True),
        'val': DistributedSampler(image_datasets['val'], num_replicas=world_size, rank=rank, shuffle=False)
    }

    # Data loaders
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=batch_size, sampler=samplers['train'], num_workers=8, pin_memory=True),
        'val': DataLoader(image_datasets['val'], batch_size=batch_size, sampler=samplers['val'], num_workers=8, pin_memory=True)
    }

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    print(f'Training dataset size: {dataset_sizes["train"]}')
    print(f'Validation dataset size: {dataset_sizes["val"]}')

    step_count = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Without set_epoch the distributed sampler reuses the same shuffle
        # order in every epoch.
        samplers['train'].set_epoch(epoch)

        # Per-attribute validation counters, kept as plain Python ints.
        val_counts = {
            attr: {'correct': 0, 'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0}
            for attr in selected_attrs
        }

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            # Samples actually processed by this rank; with a distributed
            # sampler this is a shard, not the full dataset.
            samples_seen = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(rank)
                labels = labels.float().to(rank)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()
                        step_count += 1

                        if (step_count % 10 == 0 or step_count == 1) and wandb is not None and rank == 0:
                            wandb.log({'train_loss': loss.item()})

                    else:
                        # Work on boolean tensors so that `~` is a logical NOT.
                        # (On int tensors `~` is a bitwise NOT yielding -1/-2,
                        # which corrupts the true-negative count.)
                        predicted = torch.sigmoid(outputs) > 0.5
                        # assumes labels are encoded as 0/1 — TODO confirm
                        truth = labels > 0.5

                        for attr_idx, attr in enumerate(selected_attrs):
                            pred_col = predicted[:, attr_idx]
                            true_col = truth[:, attr_idx]
                            counts = val_counts[attr]

                            counts['correct'] += (pred_col == true_col).sum().item()
                            counts['tp'] += (pred_col & true_col).sum().item()
                            counts['fp'] += (pred_col & ~true_col).sum().item()
                            counts['fn'] += (~pred_col & true_col).sum().item()
                            counts['tn'] += (~pred_col & ~true_col).sum().item()

                running_loss += loss.item() * inputs.size(0)
                samples_seen += inputs.size(0)

            epoch_loss = _safe_ratio(running_loss, samples_seen)
            print(f'{phase} Loss: {epoch_loss:.4f}')

            if phase == 'val' and rank == 0:
                # Accuracy per attribute
                for attr in selected_attrs:
                    accuracy = _safe_ratio(val_counts[attr]['correct'], samples_seen)
                    print(f'{attr} Accuracy: {accuracy:.4f}')

                if wandb is not None:
                    wandb.log({f'{phase}_loss': epoch_loss})
                    for attr in selected_attrs:
                        counts = val_counts[attr]
                        wandb.log({f'{attr}_accuracy': _safe_ratio(counts['correct'], samples_seen)})
                        # Precision, recall and F1; 0.0 when undefined
                        precision = _safe_ratio(counts['tp'], counts['tp'] + counts['fp'])
                        recall = _safe_ratio(counts['tp'], counts['tp'] + counts['fn'])
                        f1 = _safe_ratio(2 * precision * recall, precision + recall)
                        wandb.log({f'{attr}_precision': precision})
                        wandb.log({f'{attr}_recall': recall})
                        wandb.log({f'{attr}_f1': f1})

            if is_train and rank == 0:
                # Save once per epoch, right after the weights were updated;
                # the validation pass does not change them.
                torch.save(model.state_dict(), f'celeba_cnn_latent_nose_classifier_{epoch}_700_ddpm.pth')

    if rank == 0:
        wandb.finish()

    return model


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple CNN model
class SimpleCNN(nn.Module):
    """Small three-stage CNN that maps an image-like tensor to raw logits.

    Each stage is a 3x3 convolution, a ReLU and a 2x2 max-pool, so the spatial
    size is halved three times. The first fully connected layer expects
    ``64 * 8 * 8`` features per sample, which corresponds to a 64x64 input.

    Args:
        in_channels: Number of input channels (default 3).
        num_outputs: Number of output logits (default 1).

    The forward pass returns logits with no final activation; apply a sigmoid
    outside the model, or use a loss that works on logits.
    """

    def __init__(self, in_channels=3, num_outputs=1):
        super().__init__()
        # 1st convolutional layer: in_channels -> 16 feature maps
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        # 2nd convolutional layer: 16 -> 32 feature maps
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # 3rd convolutional layer: 32 -> 64 feature maps
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        # Max pooling layer shared by all three stages
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Fully connected layer: 64*8*8 input features (after pooling), 256 outputs
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        # Output layer: 256 inputs, one logit per output
        self.fc2 = nn.Linear(256, num_outputs)

    def forward(self, x):
        """Return logits of shape ``(batch, num_outputs)`` for input ``x``."""
        # Convolution + ReLU + MaxPool, three times
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, ...),
        # this can never fold samples into each other: an input of the wrong
        # spatial size fails in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, start_dim=1)
        # Fully connected layer 1 + ReLU
        x = F.relu(self.fc1(x))
        # Output layer (raw logits, no sigmoid)
        return self.fc2(x)

# Experiment description; it is also stored in the wandb run config.
condition_config = {
    'task_name': 'celeba_attribute_classifier_CNN_Nose',
    'condition_types': ['attribute'],
    'attribute_condition_config': {
        'attribute_condition_num': 1,
        # 'Heavy_Makeup' and 'Smiling' were commented out of this list.
        'attribute_condition_selected_attrs': ['Big_Nose'],
    },
}

# Hyperparameters
batch_size = 256
learning_rate = 0.001
epochs = 1

# Build the network on the GPU, then its optimizer and loss function.
# BCEWithLogitsLoss works on the raw logits the network returns.
model = SimpleCNN().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()

# Single-process run: one replica, rank 0.
model = train_model(
    model,
    criterion,
    optimizer,
    condition_config,
    world_size=1,
    rank=0,
    num_epochs=epochs,
    batch_size=batch_size,
    wandb=None,
)

wandb.finish()




[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmegleczmate[0m. Use [1m`wandb login --relogin`[0m to force relogin


  attributes_df = pd.read_csv(os.path.join(self.latent_path, 'attributes.csv'))
  attributes_df = pd.read_csv(os.path.join(self.latent_path, 'attributes.csv'))


Training dataset size: 1800000
Validation dataset size: 200000
Epoch 1/1
----------


100%|██████████| 7032/7032 [34:00<00:00,  3.45it/s]


train Loss: 0.0352


100%|██████████| 782/782 [03:49<00:00,  3.41it/s]


val Loss: 3.3570
Big_Nose Accuracy: 0.7425


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Big_Nose_accuracy,▁
Big_Nose_f1,▁
Big_Nose_precision,▁
Big_Nose_recall,▁
train_loss,█▇▅▂▂▂▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▁

0,1
Big_Nose_accuracy,0.7425
Big_Nose_f1,0.73455
Big_Nose_precision,0.75793
Big_Nose_recall,0.71257
train_loss,2e-05
val_loss,3.35702


In [5]:
# Close the wandb run if one is still active in this kernel.
wandb.finish()

In [4]:
# Load model weights from a saved checkpoint.
# NOTE(review): this is a "smile" classifier checkpoint, while the training cell
# in this notebook writes 'celeba_cnn_latent_nose_classifier_*' files — confirm
# that this is the intended file.
model = SimpleCNN()
# Deserialize onto the CPU and restrict unpickling to tensor data; the model is
# moved to the GPU as a whole afterwards.
model.load_state_dict(
    torch.load(
        'celeba_cnn_latent_smile_classifier_0_700_ddpm.pth',
        map_location='cpu',
        weights_only=True,
    )
)

model = model.to('cuda')

In [5]:
# Location of the preprocessed latents used for this evaluation
data_dir = '/mnt/g/data/latents/700/processed/smiling'

# Validation split with its 'Smiling' attribute labels
val_dataset = LatentImageDataset(data_dir, split='val', target_attributes=['Smiling'])

# Unshuffled loader over the validation split
val_dataloader = DataLoader(
    val_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=8,
)

  attributes_df = pd.read_csv(os.path.join(self.latent_path, 'attributes.csv'))


In [6]:
# Number of samples in the validation split
print(len(val_dataset))

200000


In [7]:
# Evaluate the model on the validation loader and report its accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(val_dataloader, total=len(val_dataloader)):
        images = images.to('cuda')
        labels = labels.float().to('cuda')
        outputs = model(images)
        predicted = torch.sigmoid(outputs) > 0.5
        # Give the labels exactly the shape of the predictions before comparing.
        # A (batch,) label tensor against (batch, 1) predictions would otherwise
        # broadcast to (batch, batch) and inflate `correct`; a genuine element
        # count mismatch raises here instead.
        labels = labels.reshape(predicted.shape)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# max(..., 1) keeps an empty loader from raising ZeroDivisionError
print(f'Test Accuracy: {100 * correct / max(total, 1):.2f}%')

100%|██████████| 2000/2000 [04:47<00:00,  6.95it/s]

Test Accuracy: 81.20%





In [54]:
# Debug: show the boolean predictions currently bound to `predicted`
# (presumably those of the last evaluated batch — verify after a fresh run).
predicted

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False

In [49]:
# Debug: max over dim 1 of the logits. With a single-logit output this only
# echoes each sample's logit with index 0; it is not a class prediction.
torch.max(outputs, 1)

torch.return_types.max(
values=tensor([-56.8299, -56.7637, -56.6909, -56.6829, -56.7782, -56.8402, -56.8629,
        -56.8596, -56.8605, -56.8301, -56.8279, -56.8723, -56.9383, -57.0291,
        -57.0781, -57.0882, -57.0485, -57.0069, -56.9843, -56.9552, -56.9737,
        -57.0442, -57.0857, -57.1444, -57.1491, -57.0995, -57.0527, -57.0481,
        -57.0790, -57.1067, -57.1152, -57.1335, -57.1556, -57.1923, -57.2462,
        -57.3102, -57.3600, -57.3912, -57.3886, -57.3628, -57.3947, -57.4374,
        -57.5014, -57.5159, -57.5236, -57.5355, -57.5644, -57.6168, -57.6549,
        -57.6859, -57.7002, -57.6875, -57.6609, -57.6504, -57.6415, -57.6464,
        -57.6586, -57.6732, -57.6817, -57.6635, -57.6458, -57.6260, -57.6102,
        -57.5267], device='cuda:0'),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], devic

In [46]:
# Debug: apply a sigmoid to the raw logits and print the result
activation = nn.Sigmoid()
print(activation(outputs))

tensor([[2.0849e-25],
        [2.2276e-25],
        [2.3958e-25],
        [2.4151e-25],
        [2.1955e-25],
        [2.0635e-25],
        [2.0173e-25],
        [2.0239e-25],
        [2.0221e-25],
        [2.0846e-25],
        [2.0890e-25],
        [1.9984e-25],
        [1.8707e-25],
        [1.7083e-25],
        [1.6267e-25],
        [1.6103e-25],
        [1.6756e-25],
        [1.7467e-25],
        [1.7866e-25],
        [1.8394e-25],
        [1.8057e-25],
        [1.6828e-25],
        [1.6144e-25],
        [1.5223e-25],
        [1.5151e-25],
        [1.5922e-25],
        [1.6686e-25],
        [1.6763e-25],
        [1.6251e-25],
        [1.5808e-25],
        [1.5674e-25],
        [1.5390e-25],
        [1.5053e-25],
        [1.4512e-25],
        [1.3750e-25],
        [1.2897e-25],
        [1.2271e-25],
        [1.1894e-25],
        [1.1924e-25],
        [1.2237e-25],
        [1.1852e-25],
        [1.1357e-25],
        [1.0653e-25],
        [1.0499e-25],
        [1.0419e-25],
        [1

In [42]:
# Debug: show `predicted` again.
# NOTE(review): execution counts in this notebook are out of order, so this
# shows whatever an earlier run left in the kernel — confirm after Restart & Run All.
predicted

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

In [44]:
# Debug: the raw correct / total counters behind the reported accuracy
print(correct, total)

100000 200000
