In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Map a stacked 2D image (RGB + depth + edge) to a 3D voxel grid.

    Pipeline: 2D conv encoder -> adaptive average pool to a fixed
    ``bottleneck_size`` x ``bottleneck_size`` grid -> fully-connected layer
    producing a 128 x 8 x 8 x 8 volume -> 3D transposed-conv decoder.

    Args:
        in_channels: Channels of the 2D input. The default of 5 is
            RGB (3) + depth (1) + edge (1), i.e. the channel-wise concatenation
            of the RGB, single-channel depth and single-channel edge images.
        bottleneck_size: Spatial size the encoder features are pooled to before
            the 2D->3D fully-connected transition. Pooling makes the FC input
            size independent of the image resolution. Keep it small: the FC
            layer has ``64 * bottleneck_size**2 * 65536`` weights
            (~268M at the default of 8, ~4.3B at 32).

    Input:  (N, in_channels, H, W)
    Output: (N, 1, 32, 32, 32), values in [-1, 1] (Tanh).
    """

    def __init__(self, in_channels=5, bottleneck_size=8):
        super().__init__()
        # Encoder: 2D convolutions (each stride-2 conv halves H and W).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # TODO: add more layers
        )
        # Pool to a fixed grid so the FC layer matches any input resolution
        # (a 256x256 input leaves the encoder at 64x128x128, not 64x32x32).
        self.pool = nn.AdaptiveAvgPool2d(bottleneck_size)
        # Transition from 2D features to a 128-channel 8x8x8 volume.
        self.fc = nn.Linear(64 * bottleneck_size * bottleneck_size, 128 * 8 * 8 * 8)
        # Decoder: 3D transposed convolutions, 8^3 -> 16^3 -> 32^3.
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            # TODO: add more layers
            nn.ConvTranspose3d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        features = self.pool(self.encoder(x))
        volume = self.fc(torch.flatten(features, 1))
        volume = volume.view(x.size(0), 128, 8, 8, 8)
        return self.decoder(volume)

class Discriminator(nn.Module):
    """Score 3D voxel grids as real (close to 1) or generated (close to 0).

    The convolutional stack produces a grid of per-patch probabilities
    (a 1 x 8 x 8 x 8 map for a 32^3 input); ``forward`` averages that grid so
    each sample gets exactly one score, matching labels of shape (N,).

    Input:  (N, 1, D, H, W), e.g. (N, 1, 32, 32, 32).
    Output: (N,) probabilities in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # TODO: add more layers
            nn.Conv3d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        patch_scores = self.discriminator(x)
        # One score per sample. Flattening the whole patch map (view(-1, 1))
        # would yield N * 512 values for a 32^3 input, which BCELoss rejects
        # against (N,) labels.
        return patch_scores.flatten(start_dim=1).mean(dim=1)


In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class RGBDEdgeDataset(Dataset):
    """Paired RGB / depth / edge images read from three parallel directories.

    Files are paired by position after sorting each directory's file names, so
    the three directories must hold the same number of files whose sorted
    orders correspond to one another.

    Args:
        edge_dir, depth_dir, rgb_dir: directories holding the images.
        transform: optional callable applied to every PIL image. The same
            transform is used for the 1-channel edge/depth images and the
            3-channel RGB image, so it must handle both channel counts.

    Raises:
        ValueError: if the directories contain different numbers of files.

    Items are returned as ``(rgb_image, depth_image, edge_image)``.
    """

    def __init__(self, edge_dir, depth_dir, rgb_dir, transform=None):
        self.edge_dir = edge_dir
        self.depth_dir = depth_dir
        self.rgb_dir = rgb_dir
        self.transform = transform

        self.edge_files = self._list_files(edge_dir)
        self.depth_files = self._list_files(depth_dir)
        self.rgb_files = self._list_files(rgb_dir)

        # Pairing is by index, so mismatched counts would silently misalign
        # samples (or raise IndexError deep inside a DataLoader worker).
        counts = {
            'edge': len(self.edge_files),
            'depth': len(self.depth_files),
            'rgb': len(self.rgb_files),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(
                f"Image directories must contain the same number of files, got {counts}"
            )

    @staticmethod
    def _list_files(directory):
        """Sorted names of the regular files (sub-directories skipped) in `directory`."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load(directory, filename, mode):
        """Open one image, convert it to PIL `mode`, and close the file handle."""
        with Image.open(os.path.join(directory, filename)) as img:
            return img.convert(mode)

    def __len__(self):
        return len(self.edge_files)

    def __getitem__(self, idx):
        edge_image = self._load(self.edge_dir, self.edge_files[idx], 'L')  # grayscale
        # NOTE(review): 'L' is 8-bit; if the depth maps are 16-bit PNGs this
        # conversion may clip/quantize them — confirm the depth format.
        depth_image = self._load(self.depth_dir, self.depth_files[idx], 'L')  # grayscale
        rgb_image = self._load(self.rgb_dir, self.rgb_files[idx], 'RGB')

        if self.transform:
            edge_image = self.transform(edge_image)
            depth_image = self.transform(depth_image)
            rgb_image = self.transform(rgb_image)

        return rgb_image, depth_image, edge_image

# Resize every image to 256x256 and scale pixel values from [0, 1] to [-1, 1].
# A single-element mean/std broadcasts over any channel count, so the same
# transform works for the 1-channel edge/depth images and the 3-channel RGB
# image (a 3-element mean/std raises a shape error on 1-channel tensors).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Create the dataset. Set the RGBD_DATA_ROOT environment variable to use a
# different scan folder without editing the notebook.
data_root = os.environ.get(
    'RGBD_DATA_ROOT',
    r'D:\Jupyter\3D Construction\Dataset\redwood-3dscan\data\rgbd_extract\00037',
)
edge_dir = os.path.join(data_root, 'edge')
depth_dir = os.path.join(data_root, 'depth')
rgb_dir = os.path.join(data_root, 'rgb_png')

dataset = RGBDEdgeDataset(edge_dir=edge_dir, depth_dir=depth_dir, rgb_dir=rgb_dir, transform=transform)


In [None]:
batch_size = 1  # Adjust this number based on your available memory

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG so weight init, shuffling and the dummy target are reproducible.
torch.manual_seed(0)

# Create the data loader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Define the optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Define the loss function (expects probabilities, i.e. Sigmoid outputs)
criterion = nn.BCELoss()

# Dummy "real" 3D data until real voxel models are loaded (TODO). Drawn
# uniformly from [-1, 1] to match the range of the generator's Tanh output;
# N(0, 1) samples fall outside that range ~32% of the time, which would let
# the discriminator separate real from fake without learning any shape.
target = (torch.rand(batch_size, 1, 32, 32, 32) * 2 - 1).to(device)

# Training loop
def train(generator, discriminator, dataloader, optimizer_G, optimizer_D, criterion, target, epochs=3):
    """Train the generator and discriminator adversarially.

    For each batch: one discriminator step with ``target`` labelled real and
    the detached generator output labelled fake, then one generator step that
    pushes the discriminator to label the generator output as real.

    Args:
        generator: module mapping the channel-wise concatenation of
            ``(rgb, depth, edge)`` to a fake 3D sample.
        discriminator: module mapping a 3D sample to probabilities.
        dataloader: iterable of ``(rgb, depth, edge)`` batches.
        optimizer_G, optimizer_D: optimizers for the respective modules.
        criterion: loss on ``(probabilities, labels)``, e.g. ``nn.BCELoss()``.
        target: placeholder "real" 3D batch, reused for every step.
        epochs: number of passes over ``dataloader``.

    All tensors are moved to the device of the generator's parameters, so this
    runs on CPU as well as GPU. Prints the mean D/G loss of each epoch.
    """
    device = next(generator.parameters()).device
    real_data = target.to(device)  # TODO: load real 3D models instead of a fixed placeholder
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        d_total, g_total, n_batches = 0.0, 0.0, 0
        for rgb, depth, edge in dataloader:
            # Stack the 2D inputs along the channel dimension.
            input_data = torch.cat((rgb.to(device), depth.to(device), edge.to(device)), dim=1)

            # Generate fake data
            fake_data = generator(input_data)

            # Train Discriminator: real -> 1, fake -> 0. Labels are built from
            # each prediction's own shape, so they match any batch size (e.g. a
            # smaller final batch) and any per-sample output shape.
            optimizer_D.zero_grad()
            real_pred = discriminator(real_data)
            fake_pred = discriminator(fake_data.detach())  # detach: no generator update here
            real_loss = criterion(real_pred, torch.ones_like(real_pred))
            fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Train Generator: make the discriminator output 1 for fakes.
            optimizer_G.zero_grad()
            gen_pred = discriminator(fake_data)
            g_loss = criterion(gen_pred, torch.ones_like(gen_pred))
            g_loss.backward()
            optimizer_G.step()

            d_total += d_loss.item()
            g_total += g_loss.item()
            n_batches += 1

        if n_batches == 0:
            print(f"Epoch [{epoch+1}/{epochs}]  dataloader produced no batches")
            continue
        print(f"Epoch [{epoch+1}/{epochs}]  Loss D: {d_total / n_batches}, loss G: {g_total / n_batches}")

In [None]:
# Kick off adversarial training for three epochs using the objects built above.
train(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    criterion=criterion,
    target=target,
    epochs=3,
)

In [None]:
# Persist only the learned parameters (state dicts), generator first,
# as .pth files relative to the current working directory.
torch.save(obj=generator.state_dict(), f='generator_weights.pth')
torch.save(obj=discriminator.state_dict(), f='discriminator_weights.pth')