In [1]:
import os 
os.chdir('../')
%pwd

'D:\\projects\\Project-GAN-AnimeFaces'

In [10]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(eq=True, frozen=True)
class PrepareModelConfig:
    """Immutable settings for the prepare-model stage.

    Instances are built by ``ConfigurationManager.get_prepare_model_config`` from
    the ``prepare_model`` section of the project config; being frozen, fields
    cannot be reassigned after construction.
    """

    # Directory the stage writes its artifacts into.
    root_dir: Path
    # Destination file for the generator's initial state_dict.
    generator_model: Path
    # Destination file for the discriminator's initial state_dict.
    discriminator_model: Path
    # Channel count of the generator's (latent_size x 1 x 1) noise input.
    latent_size: int

In [11]:
from AnimeFaces.constants import * 
from AnimeFaces.utils.common import read_yaml, create_directories

In [12]:
class ConfigurationManager:
    """Loads the project YAML files and builds per-stage configuration objects."""

    def __init__(self, config_filepath = CONFIG_FILE_PATH, params_filepath = PARAMS_FILE_PATH, schema_filepath = SCHEMA_FILE_PATH):
        # Parse config, params and schema (in that order). read_yaml's result supports
        # attribute access, as used for artifacts_root just below.
        self.config, self.params, self.schema = (
            read_yaml(path)
            for path in (config_filepath, params_filepath, schema_filepath)
        )

        # Top-level artifacts directory shared by all stages.
        create_directories([self.config.artifacts_root])

    def get_prepare_model_config(self) -> PrepareModelConfig:
        """Return the prepare-model settings, creating the stage's root_dir first."""
        section = self.config.prepare_model
        create_directories([section.root_dir])
        return PrepareModelConfig(
            root_dir=section.root_dir,
            generator_model=section.generator_model,
            discriminator_model=section.discriminator_model,
            latent_size=section.latent_size,
        )

In [13]:
from AnimeFaces import logger
from AnimeFaces.utils.device_utils import get_default_device

In [16]:
import torch.nn as nn
import torch
class PrepareModel:
    """Prepare-model stage: build the GAN networks and persist their initial weights.

    Each ``Prepare*Model`` method constructs a freshly initialised network, writes its
    ``state_dict`` to the path held in the stage config and returns the module. Only
    weights are saved (not the architecture), so whatever loads these files must
    rebuild the same ``nn.Sequential`` layout before calling ``load_state_dict``.
    """

    def __init__(self, config: "PrepareModelConfig"):
        self.config = config
        # NOTE(review): not used by the methods below; the models are not moved to it here.
        self.device = get_default_device()

    @staticmethod
    def _build_discriminator() -> nn.Sequential:
        """Return the discriminator: (N, 3, 64, 64) images -> (N, 1) scores in (0, 1)."""
        return nn.Sequential(
            # in: 3 x 64 x 64
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 64 x 32 x 32

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 128 x 16 x 16

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 256 x 8 x 8

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 512 x 4 x 4

            # A 4x4 kernel with no padding collapses the 4x4 map to a single score.
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            # out: 1 x 1 x 1

            nn.Flatten(),
            nn.Sigmoid()
        )

    @staticmethod
    def _build_generator(latent_size: int) -> nn.Sequential:
        """Return the generator: (N, latent_size, 1, 1) noise -> (N, 3, 64, 64) images in [-1, 1]."""
        return nn.Sequential(
            # in: latent_size x 1 x 1
            nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 64 x 32 x 32
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            # Tanh bounds pixels to [-1, 1]; presumably training images are normalised
            # to the same range — TODO confirm against the data pipeline.
            nn.Tanh()
            # out: 3 x 64 x 64
        )

    @staticmethod
    def _save_state_dict(model: nn.Module, path) -> None:
        """Write ``model``'s weights to ``path``, creating missing parent directories."""
        # Only root_dir is created by the configuration step; the weight file paths
        # are not guaranteed to sit in an existing directory, and torch.save does
        # not create parents itself.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)

    def PrepareDiscriminatorModel(self):
        """Build the discriminator, save its initial weights to ``config.discriminator_model``.

        Returns:
            The freshly built ``nn.Sequential`` discriminator.
        """
        discriminator = self._build_discriminator()
        self._save_state_dict(discriminator, self.config.discriminator_model)
        return discriminator

    def PrepareGeneratorModel(self):
        """Build the generator, save its initial weights to ``config.generator_model``.

        Returns:
            The freshly built ``nn.Sequential`` generator.
        """
        generator = self._build_generator(self.config.latent_size)
        self._save_state_dict(generator, self.config.generator_model)
        return generator

In [17]:
# Run the prepare-model stage: load config, then build and save both networks.
# Errors propagate unchanged with their original traceback.
config = ConfigurationManager()
prepare_model_config = config.get_prepare_model_config()
prepare_model = PrepareModel(config = prepare_model_config)
prepare_model.PrepareDiscriminatorModel()
# Trailing ';' keeps a returned module (if any) from being rendered as cell output.
prepare_model.PrepareGeneratorModel();

[2023-10-02 00:42:14,295: INFO: common: yaml file: config\config.yaml loaded successfully]
[2023-10-02 00:42:14,297: INFO: common: yaml file: params.yaml loaded successfully]
[2023-10-02 00:42:14,298: INFO: common: yaml file: schema.yaml loaded successfully]
[2023-10-02 00:42:14,299: INFO: common: Directory Created: artifacts]
[2023-10-02 00:42:14,301: INFO: common: Directory Created: artifacts/prepare_base_model]
