Out of memory when finetuning model #162

@wongchieh

Dear Microsoft,

I am trying to fine-tune the Aurora model, but it runs out of memory on a single forward pass with pseudo data.

The following is my training code:

"""
import os
from datetime import datetime, timedelta

import pandas as pd
import torch
import xarray as xr
from torch.utils.data import Dataset, DataLoader

from aurora import AuroraPretrained, AuroraSmallPretrained, Batch, Metadata, rollout
from aurora.normalisation import locations, scales

device = torch.device("cuda:0")
print(f"device is {device}")

# Build the model and load the pretrained 0.25-degree checkpoint
model = AuroraPretrained(
    surf_vars=("2t", "10u", "10v", "msl"),
    static_vars=("lsm", "z", "slt"),
    atmos_vars=("z", "u", "v", "t", "q"),
    bf16_mode=False,
).to(device)
model.load_checkpoint_local(path="ckpt/aurora-0.25-pretrained.ckpt", strict=True)

# Set up the optimizer, loss, and gradient scaler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = torch.nn.HuberLoss(delta=1.0)  # robust vs spikes
scaler = torch.amp.GradScaler("cuda")

# Pseudo data
batch = Batch(
    surf_vars={k: torch.randn(1, 2, 721, 1440).to(device) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(721, 1440).to(device) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 721, 1440).to(device) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 721).to(device),
        lon=torch.linspace(0, 360, 1440 + 1)[:-1].to(device),
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

# Model forward (gradients enabled, since the model is in training mode)
model.train()
pred = model.forward(batch)

# XXX
"""

As you can see, I only feed in pseudo data at 0.25-degree resolution with 13 levels. Running model.forward(batch) runs out of memory (over 80 GB).
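For scale, here is a rough sketch of the arithmetic for the input batch and the weights alone, assuming everything is stored in float32. The helper names are only for illustration, and the figure of roughly 1.3 billion parameters is an assumption taken from the published description of Aurora, not something I measured from the checkpoint. The snippet above passes 4 pressure levels, so that is the default; 13 levels can be passed explicitly.

```python
# Back-of-envelope memory estimate for the pseudo batch and the model weights.
# Everything is assumed to be float32 (4 bytes per element).
BYTES_PER_FLOAT32 = 4
H, W = 721, 1440   # 0.25-degree global grid, as in the pseudo batch
HISTORY = 2        # two time steps per variable, as in the pseudo batch


def batch_bytes(levels=4, n_surf=4, n_static=3, n_atmos=5, batch_size=1):
    """Bytes needed to hold the input tensors of one batch."""
    surf = n_surf * batch_size * HISTORY * H * W
    static = n_static * H * W
    atmos = n_atmos * batch_size * HISTORY * levels * H * W
    return (surf + static + atmos) * BYTES_PER_FLOAT32


def weight_bytes(n_params=1_300_000_000):
    """Bytes needed for the weights alone.

    ASSUMPTION: about 1.3e9 parameters, not read from the checkpoint.
    """
    return n_params * BYTES_PER_FLOAT32


if __name__ == "__main__":
    print(f"input batch, 4 levels:  {batch_bytes(levels=4) / 2**20:.1f} MiB")
    print(f"input batch, 13 levels: {batch_bytes(levels=13) / 2**20:.1f} MiB")
    print(f"float32 weights:        {weight_bytes() / 2**30:.2f} GiB")
```

Under these assumptions the inputs and weights together come to only a few GiB, so most of the 80 GB would have to be intermediate activations kept for the backward pass.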

Therefore, I would like to ask:

  1. Am I training the model and computing the gradients in the right way?
  2. How can I train the model within the available GPU memory?
  3. How can I train on multiple GPUs?

This seems to be a major hurdle when fine-tuning Aurora.

Best
