In [1]:
from utils import *

hete = Heterogeneity()
hete.check_torch_gpu()

-------------------------------------------------
------------------ VERSION INFO -----------------
Conda Environment: torchy | Python version: 3.8.16 (default, Mar  2 2023, 03:18:16) [MSC v.1916 64 bit (AMD64)]
Torch version: 2.0.1
Torch build with CUDA? True
# Device(s) available: 1, Name(s): Quadro P520



In [110]:
# Hyperparameters and configurations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device="{}"'.format(device))

lr         = 2e-4
batch_size = 32
num_epochs = 20

input_channels_X  = 4
output_channels_X = 4

input_channels_Y  = 2
output_channels_Y = 2

device="cuda"


In [None]:
X_data, y_data = np.load('X_data.npy'), np.load('y_data.npy')
print('X_data: {} | y_data: {}'.format(X_data.shape, y_data.shape))

xn = np.moveaxis(np.moveaxis(X_data, -2, 1).reshape(2000*61,64,64,4), -1, 1)
yn = np.moveaxis(np.moveaxis(y_data, -2, 1).reshape(2000*61,64,64,2), -1, 1)
print('X_reshape: {} | y_reshape: {}'.format(xn.shape, yn.shape))

train_dataset = NumpyDataset(xn, yn)
dataloader    = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [112]:
# Initialize CycleGAN
cycle_gan = CycleGAN(input_channels_X, output_channels_Y, input_channels_Y, output_channels_X).to(device)

# Optimizers
optimizer_G   = optim.Adam(cycle_gan.parameters(), lr=lr)
optimizer_D_X = optim.Adam(cycle_gan.discriminator_X.parameters(), lr=lr)
optimizer_D_Y = optim.Adam(cycle_gan.discriminator_Y.parameters(), lr=lr)
optimizers = [optimizer_G, optimizer_D_X, optimizer_D_Y]

# Loss functions
adversarial_loss       = nn.MSELoss()
cycle_consistency_loss = nn.L1Loss()
losses = [cycle_consistency_loss, adversarial_loss]

cycle_gan.train(dataloader, num_epochs, optimizers, losses)

RuntimeError: Given groups=1, weight of size [64, 2, 4, 4], expected input[64, 4, 64, 64] to have 2 channels, but got 4 channels instead

***
# END

In [106]:
class CycleGAN(nn.Module):
    def __init__(self, input_channels_X, output_channels_Y, input_channels_Y, output_channels_X):
        super(CycleGAN, self).__init__()
        self.generator_XY    = Generator(input_channels_X, output_channels_Y)
        self.generator_YX    = Generator(input_channels_Y, output_channels_X)
        self.discriminator_X = Discriminator(input_channels_X)
        self.discriminator_Y = Discriminator(input_channels_Y)
    
    def forward(self, x, y):
        fake_Y = self.generator_XY(x)
        fake_X = self.generator_YX(y)
        reconstructed_X = self.generator_YX(fake_Y)
        reconstructed_Y = self.generator_XY(fake_X)
        return fake_Y, fake_X, reconstructed_X, reconstructed_Y
    
    def train(self, dataloader, num_epochs, optimizer_list, loss_list, verbose=True, save=False, monitor=10, device='cuda'):
        optimizer_G, optimizer_D_X, optimizer_D_Y = optimizer_list
        cycle_consistency_loss, adversarial_loss  = loss_list
        for epoch in range(num_epochs):
            for i, (X,Y) in enumerate(dataloader):
                X = X.to(device)
                Y = Y.to(device)
                # Adversarial ground truths
                valid = torch.ones(X.size(0),  1, 14, 14).to(device)
                fake  = torch.zeros(X.size(0), 1, 14, 14).to(device)
                # ------------------
                #  Train Generators
                # ------------------
                optimizer_G.zero_grad()
                # Identity loss
                identity_X = self.generator_YX(X)
                identity_Y = self.generator_XY(Y)
                loss_identity = cycle_consistency_loss(identity_X, X) + cycle_consistency_loss(identity_Y, Y)
                # Adversarial loss
                fake_Y, fake_X, reconstructed_X, reconstructed_Y = self.forward(X,Y)
                loss_GAN_XY = adversarial_loss(self.discriminator_Y(fake_Y), valid)
                loss_GAN_YX = adversarial_loss(self.discriminator_X(fake_X), valid)
                loss_GAN = loss_GAN_XY + loss_GAN_YX
                # Cycle Consistency loss
                loss_cycle_X = cycle_consistency_loss(reconstructed_X, X)
                loss_cycle_Y = cycle_consistency_loss(reconstructed_Y, Y)
                loss_cycle = loss_cycle_X + loss_cycle_Y
                # Total generator loss
                loss_G = loss_identity + loss_GAN + 10*loss_cycle
                loss_G.backward()
                optimizer_G.step()
                # ---------------------
                #  Train Discriminators
                # ---------------------
                # Discriminator X loss
                optimizer_D_X.zero_grad()
                loss_real = adversarial_loss(self.discriminator_X(X), valid)
                loss_fake = adversarial_loss(self.discriminator_X(fake_X.detach()), fake)
                loss_D_X = (loss_real + loss_fake)/2
                loss_D_X.backward()
                optimizer_D_X.step()
                # Discriminator Y loss
                optimizer_D_Y.zero_grad()
                loss_real = adversarial_loss(self.discriminator_Y(Y), valid)
                loss_fake = adversarial_loss(self.discriminator_Y(fake_Y.detach()), fake)
                loss_D_Y = (loss_real + loss_fake)/2
                loss_D_Y.backward()
                optimizer_D_Y.step()
            if (epoch+1) % monitor == 0:
                if verbose:
                    print('Epoch [{}/{}]: Generator Loss: {:.4f}, Discriminator Loss: {:.4f}'.format(
                        epoch+1, num_epochs, loss_G.item(), loss_D_X.item()+loss_D_Y.item()))
                if save:
                    with torch.no_grad():
                        fake_Y, _, _, _ = self.forward(X,Y)
                        fake_images = torch.cat([X, fake_Y], dim=0)
                        save_image(fake_images, f"generated_images/epoch_{epoch+1}.png", nrow=batch_size)
        if save:
            torch.save(self.generator_XY.state_dict(),    'generator_XY.pt')
            torch.save(self.generator_YX.state_dict(),    'generator_YX.pt')
            torch.save(self.discriminator_X.state_dict(), 'discriminator_X.pt')
            torch.save(self.discriminator_Y.state_dict(), 'discriminator_Y.pt')

In [107]:
class Generator(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(Generator, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True))
        # Latent Transformer
        self.latent_transformer = SwinTransformer(512, output_channels)
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh())
    
    def forward(self, x):
        encoded = self.encoder(x)
        latent  = self.latent_transformer(encoded)
        decoded = self.decoder(latent)
        return decoded


In [108]:
class Discriminator(nn.Module):
    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
    def forward(self, x):
        out = self.model(x)
        return out

In [109]:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super(PatchMerging, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.norm = nn.LayerNorm(out_channels)
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = F.gelu(x)
        x = x.permute(0, 2, 3, 1)  # B, H, W, C
        B, H, W, C = x.shape
        x = x.reshape(B, H // 2, 2, W // 2, 2, C)
        x = x.permute(0, 1, 3, 2, 4, 5)  # B, H // 2, W // 2, 2, 2, C
        x = x.reshape(B, H // 2, W // 2, -1)
        x = x.permute(0, 3, 1, 2)  # B, 4C, H // 2, W // 2
        return x

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=4, stride=4)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # B, H, W, C
        return x

class SwinBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4):
        super(SwinBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = nn.MultiheadAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp   = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim))
    def forward(self, x):
        residual = x
        x    = self.norm1(x)
        x    = x.permute(1, 0, 2)  # H, B, C
        x, _ = self.attn(x, x, x)
        x    = x.permute(1, 0, 2)  # B, H, C
        x    = x + residual
        residual = x
        x    = self.norm2(x)
        x    = self.mlp(x)
        x    = x + residual
        return x

class SwinTransformer(nn.Module):
    def __init__(self, in_channels, out_channels, 
                 img_size=64, patch_size=4, embed_dim=96, 
                 depths=[2, 2, 6, 2], num_heads=3, mlp_ratio=4):
        super(SwinTransformer, self).__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        num_patches      = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(in_channels, embed_dim)
        self.pos_embed   = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks      = nn.ModuleList([SwinBlock(embed_dim, num_heads, mlp_ratio) for _ in range(sum(depths))])
        self.norm        = nn.LayerNorm(embed_dim)
        self.head        = nn.Linear(embed_dim, out_channels)

    def forward(self, x):
        x = self.patch_embed(x)
        B, N, C = x.shape
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x

In [None]:
class NumpyDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X).float()
        self.y = torch.from_numpy(y).float()
    def __len__(self):
        return len(self.X)
    def __getitem__(self, index):
        img_x = self.X[index]
        img_y = self.y[index]        
        return img_x, img_y