In [None]:
import sys

# Directory that presumably holds the project modules imported below
# (trainer, model, dataset); the /content/drive prefix suggests a mounted
# Google Drive in Colab — TODO confirm the mount step happens before this cell.
PROJECT_DIR = '/content/drive/MyDrive/D4D'

# Guard the append so re-running this cell does not keep adding duplicate
# entries to sys.path.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
!pip install warmup_scheduler
!pip install wandb

In [None]:
from trainer import Trainer, ConfigParser
from model import NoiseEstimationClip, NoiseEstimationCLIP_pretrained
from dataset import NoiseEstimationDataset, NoiseEstimationValidationDataset, create_dataloaders
from types import SimpleNamespace

In [None]:
def flatten_namespace(nested_namespace):
    """Collapse a nested SimpleNamespace into a single-level SimpleNamespace.

    Every attribute that is not itself a ``SimpleNamespace`` is copied onto one
    new namespace; the names of the intermediate (section) namespaces are
    dropped. Traversal is depth-first in attribute insertion order, so when the
    same leaf key appears in several sections the value seen last wins (the
    same as before). Such a collision with a differing value now emits a
    ``UserWarning`` instead of overwriting silently, because in a config it
    usually means one section's setting is being ignored.

    Args:
        nested_namespace: a ``SimpleNamespace`` whose attributes may themselves
            be ``SimpleNamespace`` instances, nested to any depth.

    Returns:
        A new ``SimpleNamespace`` holding only the leaf attributes.
    """
    import warnings

    flat_namespace = SimpleNamespace()
    flat_attrs = vars(flat_namespace)  # live __dict__ of the result

    def add_attributes(ns):
        for key, value in vars(ns).items():
            if isinstance(value, SimpleNamespace):
                add_attributes(value)
                continue
            # Surface silent clobbering of an earlier section's value.
            if key in flat_attrs and flat_attrs[key] != value:
                warnings.warn(
                    f"flatten_namespace: key {key!r} appears more than once; "
                    f"value {flat_attrs[key]!r} is overwritten by {value!r}",
                    UserWarning,
                )
            setattr(flat_namespace, key, value)

    add_attributes(nested_namespace)
    return flat_namespace

In [None]:
# Load config.yaml and flatten its (presumably nested) sections into one
# namespace, so later cells can read keys such as config.batch_size directly.
config = flatten_namespace(
    ConfigParser.parse_yaml('/content/drive/MyDrive/D4D/config.yaml')
)

In [None]:
# Training dataset.
# NOTE(review): image_dir is taken from config.valid_dir, the same directory
# the validation dataset cell uses. If config.yaml has a separate training
# directory, this trains on validation data (leakage) — confirm intent.
dataset = NoiseEstimationDataset(
    image_dir=config.valid_dir,
    clean_image=config.image_dir,
    img_size=config.image_size,
    specific_timesteps=config.specific_timesteps,
    saved_all_data_first=True,
    num_cores=config.num_cores,
)

In [None]:
# Validation dataset: same options as the training set, but clean_image comes
# from config.valid_image.
valid_dataset = NoiseEstimationValidationDataset(
    image_dir=config.valid_dir,
    clean_image=config.valid_image,
    img_size=config.image_size,
    specific_timesteps=config.specific_timesteps,
    saved_all_data_first=True,
    num_cores=config.num_cores,
)
# NOTE(review): attribute is `specified_timesteps`, not the constructor's
# `specific_timesteps`; presumably set by the dataset class — verify.
print(valid_dataset.specified_timesteps)

In [None]:
# Batch the training dataset; batch size and worker count come from config.
dataloader = create_dataloaders(
    dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)

In [None]:
# Batch the validation dataset with the same helper and settings as training.
# NOTE(review): if create_dataloaders shuffles, validation order will differ
# between runs — check its implementation.
valid_dataloader = create_dataloaders(
    valid_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)

In [None]:
# Pick the model: NoiseEstimationClip built from the architecture
# hyperparameters in config, unless config.text_image is truthy, in which case
# the pretrained CLIP variant named by config.clip_pretrained_model_name is used.
if not config.text_image:
    model = NoiseEstimationClip(
        d_model=config.d_model,
        in_channels=config.in_channels,
        image_size=config.image_size,
        patch_size=config.patch_size,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        final_embedding=config.final_embedding_dim,
    )
else:
    print('Using text image model')
    model = NoiseEstimationCLIP_pretrained(
        model_name=config.clip_pretrained_model_name
    )

In [None]:
# Hand the model, both dataloaders and the flattened config to the Trainer
# (positional order as the Trainer constructor expects).
trainer = Trainer(
    model,
    dataloader,
    valid_dataloader,
    config,
)

In [None]:
# Start training. What train() does (epochs, logging, checkpointing) is defined
# in the project's trainer module, not in this notebook.
trainer.train()