In [1]:
# Third-party
from tqdm.autonotebook import tqdm

# Project-local: trainer and the model architectures used (or kept as
# commented-out alternatives) in the cells below
from model.diffusion import Trainer
from model.unet import UNet
from model.fdnunet import FDNUNet
from model.fdnunetwithaux import FDNUNetWithAux


In [2]:
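# Alternative model (disabled): plain UNet. Kept for reference; the active model is built in a later cell.
# model = UNet(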
#     input_dim=96,
#     num_channels=2, # geometry/displacement (2)
#     num_condition_channels=3, # constraints (1) + force (2)
# )

In [3]:
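# Alternative model (disabled): FDNUNet. Kept for reference; the active model is built in a later cell.
# model = FDNUNet(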
#     input_dim=64,
#     num_channels=2, # geometry/displacement (2)
#     # num_condition_channels=1, # geometry (1)
#     num_auxiliary_condition_channels=3, # constraints (1) + force (2)
#     num_stages=4
# )

In [4]:
# Active model: every constructor argument is collected in one mapping and
# then unpacked, so the configuration can be inspected separately from the call.
model_kwargs = dict(
    input_dim=64,
    # spatial size expected by the model
    image_height=64,
    image_width=64,
    num_channels=2,  # materials (2)
    # num_condition_channels=1, # geometry (1)
    num_auxiliary_condition_channels=3,  # constraints (1) + force (2)
    num_stages=4,
    # settings for the range-prediction part of the model
    range_prediction_hidden_dim=32,
    range_prediction_num_layers=3,
)
model = FDNUNetWithAux(**model_kwargs)

In [5]:
# Count the trainable parameters only, i.e. those with requires_grad set
num_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(num_trainable_params)

60616586


In [6]:
# Print the nested module hierarchy of the model to inspect its architecture
print(model)

FDNUNetWithAux(
  (condition_feature_extractor): ConditionFeatureExtractor(
    (pre_extractors): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): SiLU()
      (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): SiLU()
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): SiLU()
      (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): SiLU()
    )
    (extractors): ModuleList(
      (0): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): SiLU()
      )
      (1): Sequential(
        (0): Downsample(
          (downsample): Sequential(
            (0): Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)
            (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (1): SiLU()
      )
      (2): Sequential(
        (0): Downsample(
          (d

In [7]:
# Trainer settings gathered in one mapping and unpacked into the constructor.
trainer_kwargs = {
    "model": model,
    # training and sampling read from the same folder in this configuration
    "dataset_folder": "data/feadata/",
    "sample_dataset_folder": "data/feadata/",
    "num_sample_conditions_per_plate": 4,
    "dataset_image_size": 64,
    "num_steps_per_condition": 6,
    "num_steps_per_sample_condition": 6,
    # effective batch = train_batch_size * num_gradient_accumulation_steps = 16
    "train_batch_size": 1,
    "train_learning_rate": 1e-4,
    "num_train_steps": 1,
    "num_gradient_accumulation_steps": 16,
    "num_steps_per_milestone": 1,
    # "ema_steps_per_milestone": 1,
    "results_folder": "results",
    "use_batch_split_over_devices": True,
}
trainer = Trainer(**trainer_kwargs)
# Alternative Trainer configuration (disabled), kept for reference:
# trainer = Trainer(
#     model=model,
#     dataset_folder='data/feadata2500',
#     use_dataset_augmentation=True,
#     sample_dataset_folder='data/sample_1/',
#     num_sample_conditions_per_plate=1,
#     dataset_image_size=256,
#     train_batch_size=8,
#     train_learning_rate=1e-4,
#     num_train_steps=16,
#     num_gradient_accumulation_steps=2,
#     num_steps_per_milestone=1,
#     ema_steps_per_milestone=1,
#     results_folder='results',
#     use_batch_split_over_devices=True,
# )
# Each effective batch (the number of samples that contribute to one parameter update) is train_batch_size * num_gradient_accumulation_steps.
# Hence, to go over the entire dataset once, we need len(dataset) / (train_batch_size * num_gradient_accumulation_steps) steps,
# which equals len(train_dataloader) / num_gradient_accumulation_steps steps.

In [8]:
# Run training with the Trainer configured above; progress and loss are reported in the cell output
trainer.train()

Epoch Size: 1.25 effective batches
Number of Effective Epochs: 0.8


loss: 5.4922 : 100%|██████████| 1/1 [00:06<00:00,  6.48s/it]


Training done!
