In [4]:
import os
import cv2
import numpy as np
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from model import Generator
from model import Discriminator

In [5]:
# sort cropped images

input_dir = 'mask_crops'
output_dir = 'sorted_crops'
size = (128, 128) 


filenames = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
labels = [os.path.splitext(f)[0].rsplit('_', 1)[1] for f in filenames]
classes = sorted(set(labels))
class_to_idx = {c: i for i, c in enumerate(classes)}

for c in classes:
    os.makedirs(os.path.join(output_dir, c), exist_ok=True)

for f, label in zip(filenames, labels):
    img = Image.open(os.path.join(input_dir, f)).convert('RGB')
    img = img.resize(size, Image.LANCZOS)
    save_path = os.path.join(output_dir, label, f)
    img.save(save_path)

file_paths = [os.path.join(label, f) for f, label in zip(filenames, labels)]
one_hot = np.zeros((len(file_paths), len(classes)), dtype=np.uint8)
for i, label in enumerate(labels):
    one_hot[i, class_to_idx[label]] = 1

np.save(os.path.join(output_dir, 'classes.npy'), np.array(classes))
np.save(os.path.join(output_dir, 'labels.npy'), one_hot)
with open(os.path.join(output_dir, 'file_paths.json'), 'w') as fp:
    json.dump(file_paths, fp, indent=2)
with open(os.path.join(output_dir, 'class_to_idx.json'), 'w') as fp:
    json.dump(class_to_idx, fp, indent=2)

print("Classes found:", len(classes))
print("Classes:", classes)



Classes found: 325
Classes: ['Bird', 'airplane', 'antenna', 'apron', 'archbishop', 'armchair', 'armor', 'arrow', 'art', 'artist', 'assassin', 'author', 'automobile', 'avocado', 'ax', 'axe', 'bag', 'ball', 'ballerina', 'balloon', 'bandage', 'banner', 'barge', 'bark', 'barrel', 'baseball', 'basin', 'basket', 'bathrobe', 'bathtub', 'baton', 'bayonet', 'bed', 'bedding', 'bedroom', 'bell', 'belt', 'bench', 'bicycle', 'bin', 'blackboard', 'blade', 'blanket', 'bludgeon', 'boat', 'bonnet', 'book', 'boot', 'bottle', 'bouquet', 'bow', 'bowl', 'box', 'brick', 'bride', 'briefcase', 'broom', 'brush', 'bucket', 'buckle', 'bugle', 'bulb', 'button', 'cabinet', 'cage', 'cake', 'camera', 'can', 'candle', 'cane', 'cannon', 'canvas', 'capacitor', 'carbine', 'carpet', 'carriage', 'cart', 'cartridge', 'carving', 'chair', 'chandelier', 'charcoal', 'chisel', 'clip', 'clock', 'coat', 'column', 'comb', 'compass', 'costume', 'couch', 'cowboy', 'crib', 'crutch', 'cup', 'cupboard', 'curtain', 'cyclist', 'dancer', 

In [None]:
object_folder = "person"

In [7]:
# augment data if needed

sorted_crops_dir = "sorted_crops"

def augment_folder(folder_path, target_count=1000):
    os.makedirs(folder_path, exist_ok=True)

    files = [f for f in os.listdir(folder_path)]
    count = len(files)
    idx = 0
    while count < target_count:

        src_name = files[np.random.randint(0, len(files))]
        src_path = os.path.join(folder_path, src_name)
        img = cv2.imread(src_path)
        if img is None:
            continue

        h, w = img.shape[:2]

        if np.random.rand() < 0.5:
            img = cv2.flip(img, 1)

        h,w = img.shape[:2]
        rand_crop = np.random.randint(int(h*0.7), h)
        x = np.random.randint(0, w - rand_crop + 1)
        y = np.random.randint(0, h - rand_crop + 1)
        img_aug = img[y:y+rand_crop, x:x+rand_crop]
        img_aug = cv2.resize(img_aug, (w, h))

        scale = np.random.uniform(0.6, 1.0)
        img_aug = np.clip(img_aug * scale, 0, 255).astype(np.uint8)

        base, ext = os.path.splitext(src_name)
        new_name = f"aug_{base}_{idx:04d}{ext}"
        save_path = os.path.join(folder_path, new_name)

        cv2.imwrite(save_path, img_aug)
        idx += 1
        count += 1

data_root = os.path.join(sorted_crops_dir, object_folder)
augment_folder(data_root, target_count=1000)


In [10]:
# train model

batch_size = 64
image_size = 128
nc = 3 
nz = 100 
ngf = 64 
ndf = 64 
n_epochs = 200
lr = 2e-4 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*nc, [0.5]*nc),
])

class SingleFolderDataset(torch.utils.data.Dataset):
    def __init__(self, folder, transform=None):
        self.paths = [os.path.join(folder, f) for f in os.listdir(folder)
                      if f.lower().endswith((".png",".jpg",".jpeg"))]
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

dataset = SingleFolderDataset(data_root, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                     num_workers=4, pin_memory=True)


def weights_init(m):
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)
netG.apply(weights_init)
netD.apply(weights_init)


optimG = optim.Adam(netG.parameters(), lr=lr, betas=(0.0, 0.9))
optimD = optim.Adam(netD.parameters(), lr=lr, betas=(0.0, 0.9))


fixed_noise = torch.randn(batch_size, nz, 1,1, device=device)
os.makedirs(f"{object_folder}_samples", exist_ok=True)
os.makedirs(f"{object_folder}_pths", exist_ok=True)


for epoch in range(1, n_epochs+1):
    prog = tqdm(loader, desc=f"Epoch {epoch}/{n_epochs}", unit="batch")
    for imgs, _ in prog:
        real = imgs.to(device)
        bsz = real.size(0)
        noise = torch.randn(bsz, nz,1,1, device=device)
        fake = netG(noise)

        optimD.zero_grad()
        real_logits = netD(real)
        fake_logits = netD(fake.detach())
        lossD = torch.mean(F.relu(1. - real_logits)) + torch.mean(F.relu(1. + fake_logits))
        lossD.backward()
        optimD.step()

        optimG.zero_grad()
        gen_logits = netD(netG(noise))
        lossG = -torch.mean(gen_logits)
        lossG.backward()
        optimG.step()

        prog.set_postfix(lossD=f"{lossD.item():.4f}", lossG=f"{lossG.item():.4f}")

    if epoch % 10 == 0:
        with torch.no_grad():
            grid = utils.make_grid(netG(fixed_noise), padding=2, normalize=True)
            utils.save_image(grid, f"{object_folder}_samples/epoch_{epoch:03d}.png")
        torch.save(netG.state_dict(), f"{object_folder}_pths/netG_epoch_{epoch}.pth")
        torch.save(netD.state_dict(), f"{object_folder}_pths/netD_epoch_{epoch}.pth")


Epoch 1/200: 100%|██████████| 16/16 [00:07<00:00,  2.16batch/s, lossD=1.3070, lossG=-0.3602]
Epoch 2/200: 100%|██████████| 16/16 [00:07<00:00,  2.15batch/s, lossD=1.3621, lossG=-0.0583]
Epoch 3/200: 100%|██████████| 16/16 [00:07<00:00,  2.14batch/s, lossD=1.2860, lossG=-0.0064]
Epoch 4/200: 100%|██████████| 16/16 [00:07<00:00,  2.13batch/s, lossD=1.3379, lossG=0.1056] 
Epoch 5/200: 100%|██████████| 16/16 [00:07<00:00,  2.12batch/s, lossD=1.4781, lossG=-0.0855]
Epoch 6/200: 100%|██████████| 16/16 [00:07<00:00,  2.11batch/s, lossD=1.1644, lossG=0.5010] 
Epoch 7/200: 100%|██████████| 16/16 [00:07<00:00,  2.12batch/s, lossD=1.1134, lossG=0.8380] 
Epoch 8/200: 100%|██████████| 16/16 [00:07<00:00,  2.11batch/s, lossD=2.1786, lossG=0.1527]
Epoch 9/200: 100%|██████████| 16/16 [00:07<00:00,  2.11batch/s, lossD=1.3094, lossG=-0.1146]
Epoch 10/200: 100%|██████████| 16/16 [00:07<00:00,  2.10batch/s, lossD=1.5983, lossG=-0.3164]
Epoch 11/200: 100%|██████████| 16/16 [00:07<00:00,  2.10batch/s, lossD

In [None]:
# continue training if needed
n_epochs = 100

netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)


optimG = optim.Adam(netG.parameters(), lr=lr, betas=(0.0, 0.9))
optimD = optim.Adam(netD.parameters(), lr=lr, betas=(0.0, 0.9))


state_dict = torch.load(f"{object_folder}_pths/netG_epoch_200.pth", map_location=device)
netG.load_state_dict(state_dict)
state_dict = torch.load(f"{object_folder}_pths/netD_epoch_200.pth", map_location=device)
netD.load_state_dict(state_dict)

for epoch in range(1, n_epochs+1):
    prog = tqdm(loader, desc=f"Epoch {epoch}/{n_epochs}", unit="batch")
    for imgs, _ in prog:
        real = imgs.to(device)
        bsz = real.size(0)
        noise = torch.randn(bsz, nz,1,1, device=device)
        fake = netG(noise)

        optimD.zero_grad()
        real_logits = netD(real)
        fake_logits = netD(fake.detach())
        lossD = torch.mean(F.relu(1. - real_logits)) + torch.mean(F.relu(1. + fake_logits))
        lossD.backward()
        optimD.step()

        optimG.zero_grad()
        gen_logits = netD(netG(noise))
        lossG = -torch.mean(gen_logits)
        lossG.backward()
        optimG.step()

        prog.set_postfix(lossD=f"{lossD.item():.4f}", lossG=f"{lossG.item():.4f}")

    if epoch % 10 == 0:
        with torch.no_grad():
            grid = utils.make_grid(netG(fixed_noise), padding=2, normalize=True)
            utils.save_image(grid, f"{object_folder}_samples/epoch_{epoch+200:03d}.png")
        torch.save(netG.state_dict(), f"{object_folder}_pths/netG_epoch_{epoch+200}.pth")
        torch.save(netD.state_dict(), f"{object_folder}_pths/netD_epoch_{epoch+200}.pth")


  state_dict = torch.load(f"{object_folder}_pths/netG_epoch_200.pth", map_location=device)
  state_dict = torch.load(f"{object_folder}_pths/netD_epoch_200.pth", map_location=device)
Epoch 1/100: 100%|██████████| 16/16 [00:07<00:00,  2.21batch/s, lossD=0.5582, lossG=0.3672]
Epoch 2/100: 100%|██████████| 16/16 [00:07<00:00,  2.20batch/s, lossD=0.6361, lossG=0.8199]
Epoch 3/100: 100%|██████████| 16/16 [00:07<00:00,  2.18batch/s, lossD=0.3815, lossG=1.2214]
Epoch 4/100: 100%|██████████| 16/16 [00:07<00:00,  2.18batch/s, lossD=0.8129, lossG=1.7784]
Epoch 5/100: 100%|██████████| 16/16 [00:07<00:00,  2.17batch/s, lossD=0.5332, lossG=2.4298]
Epoch 6/100: 100%|██████████| 16/16 [00:07<00:00,  2.17batch/s, lossD=0.7262, lossG=2.4343]
Epoch 7/100: 100%|██████████| 16/16 [00:07<00:00,  2.16batch/s, lossD=0.3472, lossG=1.6165]
Epoch 8/100: 100%|██████████| 16/16 [00:07<00:00,  2.16batch/s, lossD=0.4027, lossG=1.6045]
Epoch 9/100: 100%|██████████| 16/16 [00:07<00:00,  2.15batch/s, lossD=0.4588, loss