In [None]:
# based on A's work
def train(zarr_path, s3_bucket, output_dir, log_dir, checkpoint_file=None):
    """Train ``PrithviSegmentation`` under DDP, logging and checkpointing each epoch to S3.

    Must be called on every rank of an already-initialised ``torch.distributed``
    process group: it calls ``dist.get_rank``, ``dist.all_reduce`` (via
    ``_global_mean``) and ``dist.barrier``.

    Per epoch, every rank trains and validates on its own ``DistributedSampler``
    shard. Loss (mean per batch) and pixel accuracy (correct / valid pixels,
    ignoring label 255) are then aggregated across all ranks. Rank 0 writes a
    JSON log to ``s3_bucket`` and saves a checkpoint whenever the global
    validation loss improves.

    Args:
        zarr_path: Store passed to ``WorldCoverZarrDataset`` for both splits.
        s3_bucket: Bucket that receives the per-epoch JSON logs.
        output_dir: Key prefix for checkpoint files. Joined by plain string
            concatenation, so include any trailing separator.
        log_dir: Key prefix for log files. Also joined by string concatenation.
        checkpoint_file: Optional checkpoint name to resume from. When given,
            the model/optimizer state, next epoch and best validation loss are
            restored from it.
    """
    start_epoch = 0
    num_epochs = 10
    batch_size = 8
    learning_rate = 0.001
    num_classes = 11
    ignore_index = 255  # label value excluded from both the loss and the accuracy

    worker_rank = int(dist.get_rank())
    # NOTE(review): every rank uses CUDA device 0. That is only right when each
    # worker process sees exactly one GPU; with several processes per node this
    # should be the local rank. Confirm against the launcher.
    device = torch.device(0)

    # NOTE(review): both splits are built from the same store with identical
    # arguments, so "validation" runs on training data unless
    # WorldCoverZarrDataset splits internally. Confirm.
    train_ds = WorldCoverZarrDataset(zarr_path, patch_size=224, num_classes=num_classes)
    val_ds = WorldCoverZarrDataset(zarr_path, patch_size=224, num_classes=num_classes)

    # Sharding and shuffling are done by the samplers; DataLoader forbids
    # shuffle=True when a sampler is supplied.
    sampler_train = DistributedSampler(train_ds)
    sampler_val = DistributedSampler(val_ds)
    traingen = DataLoader(train_ds, batch_size=batch_size, sampler=sampler_train, shuffle=False)
    valgen = DataLoader(val_ds, batch_size=batch_size, sampler=sampler_val, shuffle=False)

    # np.Inf no longer exists in NumPy 2.0; a plain float works everywhere.
    val_loss_min = float("inf")

    model = PrithviSegmentation(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # DDP all-reduces gradients across ranks during backward().
    model = DDP(model, device_ids=[0])
    # Unwrapped module, so saved state_dict keys carry no "module." prefix.
    model_without_ddp = model.module

    loss_function = nn.CrossEntropyLoss(ignore_index=ignore_index)

    if checkpoint_file:
        # NOTE(review): `bucket`, `userid` and `project_name` are not parameters.
        # They must exist as module/notebook globals. Logs go to `s3_bucket`
        # instead; confirm the two are meant to differ.
        checkpoint_connection = S3Checkpoint(region='ap-southeast-2')
        with checkpoint_connection.reader(f"s3://{bucket}/{userid}/{project_name}/" + checkpoint_file) as reader:
            checkpoint = torch.load(reader, map_location=device)
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        # Must be on the same scale as the global per-batch mean computed below.
        val_loss_min = checkpoint['val_loss_min']

    for epoch in range(start_epoch, num_epochs):
        log = {"worker_rank": worker_rank, "epoch": epoch, "start_time": datetime.now().isoformat()}

        # Re-seed the samplers so the shard order changes each epoch.
        sampler_train.set_epoch(epoch)
        sampler_val.set_epoch(epoch)

        # --- TRAIN ---
        tr_loss_sum, tr_batches, tr_correct, tr_valid = _run_epoch(
            model, traingen, loss_function, device, optimizer, ignore_index
        )
        # Aggregated over all ranks, so the numbers describe the whole dataset.
        log["train_loss"] = _global_mean(tr_loss_sum, tr_batches, device)
        log["train_accuracy"] = _global_mean(tr_correct, tr_valid, device)
        log["end_time_train"] = datetime.now().isoformat()

        # --- VAL ---
        va_loss_sum, va_batches, va_correct, va_valid = _run_epoch(
            model, valgen, loss_function, device, None, ignore_index
        )
        # Identical on every rank after the all-reduce, so rank 0's checkpoint
        # decision reflects all shards rather than only its own.
        val_loss = _global_mean(va_loss_sum, va_batches, device)
        log["val_loss"] = val_loss
        log["val_accuracy"] = _global_mean(va_correct, va_valid, device)
        log["end_time_val"] = datetime.now().isoformat()

        dist.barrier()
        if worker_rank == 0:
            boto3.client("s3").put_object(
                Body=json.dumps(log),
                Bucket=s3_bucket,
                Key=log_dir + f"log_epoch_{epoch}.json",
            )
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                output_file = output_dir + 'model_{}.pth'.format(epoch)
                checkpoint = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'val_loss_min': val_loss_min,
                }
                checkpoint_connection = S3Checkpoint(region='ap-southeast-2')
                with checkpoint_connection.writer(f"s3://{bucket}/{userid}/{project_name}/" + output_file) as writer:
                    torch.save(checkpoint, writer)


def _run_epoch(model, loader, loss_function, device, optimizer=None, ignore_index=255):
    """Make one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Returns:
        ``(loss_sum, num_batches, correct_pixels, valid_pixels)`` for THIS rank
        only. Combine across ranks with ``_global_mean``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    num_batches = 0
    correct_pixels = 0
    valid_pixels = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            with torch.no_grad():
                correct, valid = _pixel_counts(outputs, labels, ignore_index)
            loss_sum += loss.item()
            num_batches += 1
            correct_pixels += correct
            valid_pixels += valid
    return loss_sum, num_batches, correct_pixels, valid_pixels


def _pixel_counts(outputs, labels, ignore_index=255):
    """Return ``(correct, valid)`` pixel counts for one batch, skipping ``ignore_index`` labels.

    The class axis of ``outputs`` is dim 1, since the argmax is taken over it.
    """
    preds = outputs.argmax(dim=1)
    valid = labels != ignore_index
    correct = (preds[valid] == labels[valid]).sum().item()
    return correct, valid.sum().item()


def _global_mean(numerator, denominator, device):
    """Sum both values over all ranks and return their ratio (0.0 if the denominator is 0).

    Every rank must call this the same number of times, in the same order,
    because it performs a collective ``all_reduce``.
    """
    totals = torch.tensor([float(numerator), float(denominator)], dtype=torch.float64, device=device)
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    num, den = totals.tolist()
    return num / den if den > 0 else 0.0