In [2]:
import unittest
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision.transforms import Resize, Grayscale, ToTensor
from sklearn.model_selection import train_test_split
from PIL import Image
import os
import json

def run_test_suite(test_class):
    suite = unittest.TestLoader().loadTestsFromTestCase(test_class)
    unittest.TextTestRunner(verbosity=2).run(suite)


In [3]:
Dataset

torch.utils.data.dataset.Dataset

Storm Dataset

In [4]:
class StormDataset(Dataset):
    """
    Custom dataset class for storm data.

    Args:
        root_dir (str): Root directory containing storm data.
        storm_id (str or list of str): Storm IDs for the dataset.
        sequence_length (int): Length of sequences to extract.
        split (str): Dataset split ('train' or 'test').
        test_size (float): Proportion of data to use for testing (if split is 'train').

    Attributes:
        root_dir (str): Root directory containing storm data.
        sequence_length (int): Length of sequences to extract.
        transform (torchvision.transforms.Compose): Image transformations.
        storm_id (str or list of str): Storm IDs for the dataset.
        sequences (list): List of sequences containing images, features, and labels.

    Methods:
        _load_and_process_data(): Loads and processes storm data.
    """

    def __init__(self, root_dir, storm_id, sequence_length=15, split='train', test_size=0.2):
        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor()
        ])
        self.storm_id = storm_id
        self.sequences = []
        self._load_and_process_data()

        # Split the dataset into train and test sets
        train_sequences, test_sequences = train_test_split(self.sequences, test_size=test_size, random_state=42)
        self.sequences = train_sequences if split == 'train' else test_sequences

    def _load_and_process_data(self):
        """
        Load and process storm data, extracting sequences of images, features, and labels.
        """
        time_features = []

        storms = self.storm_id

        for storm_id in storms:
            storm_path = os.path.join(self.root_dir, storm_id)
            all_files = os.listdir(storm_path)

            temp_images = []
            temp_features = []
            temp_labels = []

            for file in sorted(all_files):
                if file.endswith('.jpg'):
                    image_path = os.path.join(storm_path, file)
                    image = Image.open(image_path)
                    temp_images.append(self.transform(image))
                elif file.endswith('_features.json') or file.endswith('_label.json'):
                    with open(os.path.join(storm_path, file), 'r') as f:
                        data = json.load(f)
                        if file.endswith('_features.json'):
                            temp_features.append([float(data['relative_time']), float(data['ocean'])])
                        else:
                            temp_labels.append(float(data['wind_speed']))
            max_relative_time = max([f[0] for f in temp_features])
            for feature in temp_features:
                feature[0] /= max_relative_time
            time_features.extend([f[0] for f in temp_features])

            for i in range(len(temp_images) - self.sequence_length):
                self.sequences.append({
                    'images': temp_images[i:i + self.sequence_length],
                    'features': temp_features[i:i + self.sequence_length],
                    'labels': temp_labels[i:i + self.sequence_length]
                })

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return {
            'images': torch.stack(sequence['images']),
            'features': torch.tensor(sequence['features'], dtype=torch.float),
            'labels': torch.tensor(sequence['labels'], dtype=torch.float)
        }

In [5]:
import torch.nn as nn
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.input_channel = input_dim
        self.hidden_channel = hidden_dim
        self.kernel_sz = kernel_size
        self.pad = kernel_size[0] // 2, kernel_size[1] // 2
        self.use_bias = bias
        self.conv = nn.Conv2d(in_channels=self.input_channel + self.hidden_channel,
                              out_channels=4 * self.hidden_channel,
                              kernel_size=self.kernel_sz,
                              padding=self.pad,
                              bias=self.use_bias)
    def forward(self, input_tensor, cur_state):
        h_current, c_current = cur_state
        combined = torch.cat([input_tensor, h_current], dim=1)
        conv_result = self.conv(combined)
        cc_inputgate, cc_forgetgate, cc_outputgate, cc_cellgate = torch.split(conv_result, self.hidden_channel, dim=1)
        input_gate = torch.sigmoid(cc_inputgate)
        forget_gate = torch.sigmoid(cc_forgetgate)
        output_gate = torch.sigmoid(cc_outputgate)
        cell_gate = torch.tanh(cc_cellgate)
        c_next = forget_gate * c_current + input_gate * cell_gate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_channel, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_channel, height, width, device=self.conv.weight.device))

class ConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()
        self._check_kernel_size_consistency(kernel_size)
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)
    def forward(self, input_tensor, hidden_state=None):
        if not self.batch_first:
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
        b, _, _, h, w = input_tensor.size()
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))
        layer_output_list = []
        last_state_list = []
        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor
        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h, c])
        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
        return layer_output_list, last_state_list

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states



In [6]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.conv_layers(x)

class StormGenerator(nn.Module):
    def __init__(self):
        super(StormGenerator, self).__init__()
        self.encoder = SimpleCNN()
        self.conv_lstm = ConvLSTM(input_dim=128, hidden_dim=[64, 32], kernel_size=(3, 3), num_layers=2, batch_first=True)
        self.decoder = nn.Sequential(
          nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
          nn.ReLU(inplace=True),
          nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
          nn.ReLU(inplace=True),
          nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1),
          nn.Tanh()
      )
    def forward(self, input_imgs):
        batch_size, sequence_len, c, h, w = input_imgs.size()
        # 修改这里，使用 .reshape() 而不是 .view()
        c_input = input_imgs.reshape(batch_size * sequence_len, c, h, w)
        c_output = self.encoder(c_input)
        # 注意这里可能也需要使用 .reshape()，取决于后续操作
        c_output = c_output.view(batch_size, sequence_len, -1, h // 8, w // 8)
        conv_lstm_out, _ = self.conv_lstm(c_output)
        conv_lstm_out = conv_lstm_out[0][:, -1, :, :, :]
        output_image = self.decoder(conv_lstm_out)
        return output_image


model = StormGenerator()


In [8]:
# Function to retrieve a list of storm IDs from a specified directory
def get_storm_ids(root_dir):
    """
    Get a list of storm IDs from the specified directory.

    Args:
        root_dir (str): The root directory containing storm data folders.

    Returns:
        list: A list of storm IDs.
    """
    storm_ids = [name for name in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, name))]
    return storm_ids

# Set the root directory for storm data
root_dir = '/Users/mk1923/Downloads/Selected_Storms_curated_to_zip'

# Initialize the device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the list of storm IDs using the get_storm_ids function
storms = get_storm_ids(root_dir)


In [13]:
# Function to retrieve a list of storm IDs from a specified directory
def get_storm_ids(root_dir):
    """
    Get a list of storm IDs from the specified directory.

    Args:
        root_dir (str): The root directory containing storm data folders.

    Returns:
        list: A list of storm IDs.
    """
    storm_ids = [name for name in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, name))]
    return storm_ids

# Set the root directory for storm data
root_dir = '/Users/mk1923/Downloads/Selected_Storms_curated_to_zip'

# Initialize the device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the list of storm IDs using the get_storm_ids function
storms = get_storm_ids(root_dir)


Tests for Dataset

In [14]:
class StormDatasetTest(unittest.TestCase):
    def setUp(self):
        self.root_dir = '/Users/mk1923/Downloads/Selected_Storms_curated_to_zip'
        # Dynamically retrieve storm IDs
        self.storm_ids = get_storm_ids(self.root_dir)  # Using the function to get storm IDs
        self.sequence_length = 11

        # Testing with the first retrieved storm ID for simplicity; expand as needed
        storm_id_to_test = self.storm_ids[:1] if self.storm_ids else ['blq']

        # Create the dataset
        self.dataset = StormDataset(root_dir=self.root_dir, storm_id=storm_id_to_test, sequence_length=self.sequence_length, split='train')

        self.train_dataset = StormDataset(root_dir=self.root_dir, storm_id=storm_id_to_test, sequence_length=self.sequence_length, split='train')
        self.test_dataset = StormDataset(root_dir=self.root_dir, storm_id=storm_id_to_test, sequence_length=self.sequence_length, split='test')


    def test_dataset_loading(self):
        # Test a few samples for shape checks to ensure broader coverage
        for idx in range(min(len(self.dataset), 5)):  # Test up to 5 samples
            sample = self.dataset[idx]
            self.assertEqual(sample['images'].shape, (self.sequence_length, 1, 224, 224), "Input sequence shape is incorrect for sample {}".format(idx))
            self.assertEqual(sample['features'].shape, (self.sequence_length, 2), "Features shape is incorrect for sample {}".format(idx))
            self.assertEqual(sample['labels'].shape, (self.sequence_length,), "Labels shape is incorrect for sample {}".format(idx))

    def test_unique_storms_count(self):
        # Testing against dynamic retrieval of storm IDs
        expected_count = len(self.storm_ids)
        found_storms = len(set(self.storm_ids))  # Assuming each directory name is unique
        self.assertEqual(expected_count, found_storms, "The number of unique storms does not match the expected count")

    def test_image_transformations(self):
        # Test the first image in the sequence for a few samples
        for idx in range(min(len(self.dataset), 5)):  # Test up to 5 samples
            sample = self.dataset[idx]
            image = sample['images'][0]  # First image
            self.assertEqual(image.shape[0], 1, "Image should be converted to grayscale for sample {}".format(idx))
            self.assertEqual(image.shape[1:], (224, 224), "Image size should be resized to 224x224 for sample {}".format(idx))

    def test_feature_normalization(self):
        # Test feature normalization for a few samples
        for idx in range(min(len(self.dataset), 5)):  # Test up to 5 samples
            sample = self.dataset[idx]
            features = sample['features']
            normalized_values = features[:, 0]  # Assuming first feature is 'relative_time'
            self.assertTrue(torch.all((0 <= normalized_values) & (normalized_values <= 1)), "Feature values should be normalized between 0 and 1 for sample {}".format(idx))

    def test_label_processing(self):
        # Test label processing for a few samples
        for idx in range(min(len(self.dataset), 5)):  # Test up to 5 samples
            sample = self.dataset[idx]
            labels = sample['labels']
            self.assertTrue(torch.is_tensor(labels), "Labels should be a PyTorch tensor for sample {}".format(idx))


    def test_image_normalization(self):
        # Test that image pixel values are normalized between 0 and 1.
        sample = self.dataset[0]
        images = sample['images']
        self.assertTrue(torch.all(images <= 1) and torch.all(images >= 0), "Image pixel values should be normalized between 0 and 1.")

    def test_sequence_continuity(self):
        # Test for sequence continuity in time features, if applicable.
        sample = self.dataset[0]
        features = sample['features']
        # Ensure continuity if your features include time as the first element.
        time_differences = torch.diff(features[:, 0])
        self.assertTrue(torch.all(time_differences > 0), "Time features should be continuously increasing.")

    def test_data_transformation_effectiveness(self):
        # Test the effectiveness of data transformations.
        sample = self.dataset[0]
        images = sample['images']
        self.assertEqual(images.shape[1:], (1, 224, 224), "Transformed images should have the correct shape.")

    def test_dataset_length_and_indexing(self):
        # Test the reported length of the dataset and ability to index into it.
        self.assertTrue(len(self.dataset) > 0, "Dataset should report a non-zero length.")
        try:
            sample = self.dataset[len(self.dataset) - 1]
        except IndexError:
            self.fail("Indexing the last element of the dataset raised an IndexError.")



# Run StormDatasetTest
run_test_suite(StormDatasetTest)


test_data_transformation_effectiveness (__main__.StormDatasetTest.test_data_transformation_effectiveness) ... ok
test_dataset_length_and_indexing (__main__.StormDatasetTest.test_dataset_length_and_indexing) ... ok
test_dataset_loading (__main__.StormDatasetTest.test_dataset_loading) ... ok
test_feature_normalization (__main__.StormDatasetTest.test_feature_normalization) ... ok
test_image_normalization (__main__.StormDatasetTest.test_image_normalization) ... ok
test_image_transformations (__main__.StormDatasetTest.test_image_transformations) ... ok
test_label_processing (__main__.StormDatasetTest.test_label_processing) ... ok
test_sequence_continuity (__main__.StormDatasetTest.test_sequence_continuity) ... ok
test_unique_storms_count (__main__.StormDatasetTest.test_unique_storms_count) ... ok

----------------------------------------------------------------------
Ran 9 tests in 10.254s

OK


In [17]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset

root_dir = '/Users/mk1923/Downloads/Selected_Storms_curated_to_zip'

train_dataset = StormDataset(root_dir, storm_id=storms, sequence_length=11, split='train')
val_dataset = StormDataset(root_dir, storm_id=storms, sequence_length=11, split='test')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=3)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=3)

Tests for Datalaoder 


In [None]:
class DataLoaderTest(unittest.TestCase):
    def setUp(self):
        # Setup common parameters for tests
        self.root_dir = '/content/drive/MyDrive/Selected_Storms_curated'
        self.storm_id = ['bkh']
        self.sequence_length = 11
        self.batch_size = 32
        self.train_dataset = StormDataset(root_dir=self.root_dir, storm_id=self.storm_id, sequence_length=self.sequence_length, split='train')
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def test_dataloader_output(self):
        batch = next(iter(self.train_loader))
        images, features, labels = batch['images'], batch['features'], batch['labels']

        self.assertEqual(images.shape, (self.batch_size, self.sequence_length, 1, 224, 224), "Batched images shape is incorrect")
        self.assertEqual(features.shape, (self.batch_size, self.sequence_length, 2), "Batched features shape is incorrect")
        self.assertEqual(labels.shape, (self.batch_size, self.sequence_length), "Batched labels shape is incorrect")

    def test_total_images_processed(self):
        total_images = 0
        for batch in self.train_loader:
            total_images += batch['images'].size(0) * batch['images'].size(1)  # batch size * sequence length

        expected_total_images = len(self.train_dataset) * self.sequence_length
        self.assertEqual(total_images, expected_total_images, f"Total processed images should be {expected_total_images}, got {total_images}")

# Run StormDatasetTest
run_test_suite(DataLoaderTest)


test_dataloader_output (__main__.DataLoaderTest) ... ok
test_total_images_processed (__main__.DataLoaderTest) ... ok

----------------------------------------------------------------------
Ran 2 tests in 27.754s

OK


Model Architecture

In [18]:
import torch.nn as nn
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.input_channel = input_dim
        self.hidden_channel = hidden_dim
        self.kernel_sz = kernel_size
        self.pad = kernel_size[0] // 2, kernel_size[1] // 2
        self.use_bias = bias
        self.conv = nn.Conv2d(in_channels=self.input_channel + self.hidden_channel,
                              out_channels=4 * self.hidden_channel,
                              kernel_size=self.kernel_sz,
                              padding=self.pad,
                              bias=self.use_bias)
    def forward(self, input_tensor, cur_state):
        h_current, c_current = cur_state
        combined = torch.cat([input_tensor, h_current], dim=1)
        conv_result = self.conv(combined)
        cc_inputgate, cc_forgetgate, cc_outputgate, cc_cellgate = torch.split(conv_result, self.hidden_channel, dim=1)
        input_gate = torch.sigmoid(cc_inputgate)
        forget_gate = torch.sigmoid(cc_forgetgate)
        output_gate = torch.sigmoid(cc_outputgate)
        cell_gate = torch.tanh(cc_cellgate)
        c_next = forget_gate * c_current + input_gate * cell_gate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_channel, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_channel, height, width, device=self.conv.weight.device))

class ConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()
        self._check_kernel_size_consistency(kernel_size)
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)
    def forward(self, input_tensor, hidden_state=None):
        if not self.batch_first:
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
        b, _, _, h, w = input_tensor.size()
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))
        layer_output_list = []
        last_state_list = []
        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor
        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h, c])
        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
        return layer_output_list, last_state_list

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states



In [19]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.conv_layers(x)

class StormGenerator(nn.Module):
    def __init__(self):
        super(StormGenerator, self).__init__()
        self.encoder = SimpleCNN()
        self.conv_lstm = ConvLSTM(input_dim=128, hidden_dim=[64, 32], kernel_size=(3, 3), num_layers=2, batch_first=True)
        self.decoder = nn.Sequential(
          nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
          nn.ReLU(inplace=True),
          nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
          nn.ReLU(inplace=True),
          nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1),
          nn.Tanh()
      )
    def forward(self, input_imgs):
        batch_size, sequence_len, c, h, w = input_imgs.size()
        # 修改这里，使用 .reshape() 而不是 .view()
        c_input = input_imgs.reshape(batch_size * sequence_len, c, h, w)
        c_output = self.encoder(c_input)
        # 注意这里可能也需要使用 .reshape()，取决于后续操作
        c_output = c_output.view(batch_size, sequence_len, -1, h // 8, w // 8)
        conv_lstm_out, _ = self.conv_lstm(c_output)
        conv_lstm_out = conv_lstm_out[0][:, -1, :, :, :]
        output_image = self.decoder(conv_lstm_out)
        return output_image


model = StormGenerator()


Tests for Model Archeticture Architecture

In [20]:
import torch
import unittest
import torch.nn as nn

class TestModelArchitecture(unittest.TestCase):
    def setUp(self):
        # Set up common variables and model instances for tests
        self.input_dim = 128
        self.hidden_dim = [64, 32]
        self.kernel_size = (3, 3)
        self.num_layers = 2
        self.batch_size = 1
        self.seq_len = 5
        self.channels = 1
        self.height = self.width = 224  # Assuming square input images

        # Initialize models
        self.conv_lstm_cell = ConvLSTMCell(self.input_dim, self.hidden_dim[0], self.kernel_size, True)
        self.conv_lstm = ConvLSTM(self.input_dim, self.hidden_dim, self.kernel_size, self.num_layers, batch_first=True)
        self.simple_cnn = SimpleCNN()
        self.storm_generator = StormGenerator()

    def test_initialization(self):
        # Test for ConvLSTMCell
        self.assertIsInstance(self.conv_lstm_cell.conv, nn.Conv2d, "ConvLSTMCell Conv2d initialization failed")

        # Test for ConvLSTM
        self.assertEqual(len(self.conv_lstm.cell_list), self.num_layers, "ConvLSTM cell_list initialization failed")

        # Test for SimpleCNN
        self.assertTrue(isinstance(self.simple_cnn.conv_layers, nn.Sequential), "SimpleCNN Sequential initialization failed")

        # Test for StormGenerator
        self.assertTrue(isinstance(self.storm_generator.encoder, SimpleCNN), "StormGenerator encoder initialization failed")
        self.assertTrue(isinstance(self.storm_generator.conv_lstm, ConvLSTM), "StormGenerator ConvLSTM initialization failed")
        self.assertTrue(isinstance(self.storm_generator.decoder, nn.Sequential), "StormGenerator decoder initialization failed")

    def test_forward_pass(self):
        # Simulate input for forward pass
        input_tensor = torch.randn(self.batch_size, self.seq_len, self.channels, self.height, self.width)
        output = self.storm_generator(input_tensor)
        self.assertEqual(output.shape, (self.batch_size, self.channels, self.height, self.width), "StormGenerator forward pass output shape is incorrect")

    def test_layer_consistency(self):
        # Verifying the consistency of layers within ConvLSTM and SimpleCNN components

        # SimpleCNN layer tests
        conv_layers_count = sum(1 for _ in filter(lambda layer: isinstance(layer, nn.Conv2d), self.simple_cnn.conv_layers))
        self.assertEqual(conv_layers_count, 3, "SimpleCNN does not contain the expected number of Conv2d layers")

        # ConvLSTM layer tests
        for i, cell in enumerate(self.conv_lstm.cell_list):
            self.assertTrue(isinstance(cell, ConvLSTMCell), f"Layer {i} in ConvLSTM is not a ConvLSTMCell")
            self.assertEqual(cell.kernel_sz, self.kernel_size, f"ConvLSTMCell {i} kernel size is incorrect")
            self.assertTrue(cell.use_bias, f"ConvLSTMCell {i} bias is not being used as expected")

    def test_matrix_shapes_compatibility(self):
        model = StormGenerator()
        mock_input = torch.randn(32, 11, 1, 224, 224)  # Adjust mock input to match your model's expected input shape
        try:
            model(mock_input)  # Perform a forward pass with mock input
        except RuntimeError as e:
            if "size mismatch" in str(e):
                self.fail(f"Matrix shape conflict encountered: {e}")

    def test_simple_cnn_output_shape(self):
        input_tensor = torch.randn(self.batch_size, self.channels, self.height, self.width)
        output = self.simple_cnn(input_tensor)
        expected_output_shape = (self.batch_size, 128, self.height // 8, self.width // 8)  # Adjust based on expected output
        self.assertEqual(output.shape, expected_output_shape, "SimpleCNN output shape is incorrect")

    def test_simple_cnn_relu_activations(self):
        input_tensor = torch.randn(self.batch_size, self.channels, self.height, self.width)
        output = self.simple_cnn(input_tensor)
        self.assertTrue(torch.all(output >= 0), "SimpleCNN ReLU activations not applied correctly")

    def test_simple_cnn_parameter_count(self):
        num_params = sum(p.numel() for p in self.simple_cnn.parameters())
        expected_params = 92672  # Adjust based on the expected count
        self.assertEqual(num_params, expected_params, "SimpleCNN parameter count is incorrect")

    def test_storm_generator_output_range(self):
        input_tensor = torch.randn(self.batch_size, self.seq_len, self.channels, self.height, self.width)
        output = self.storm_generator(input_tensor)
        self.assertTrue(torch.all(output >= -1) and torch.all(output <= 1), "StormGenerator output range is incorrect")


    def test_gradient_calculation(self):
        input_tensor = torch.randn(self.batch_size, self.seq_len, self.channels, self.height, self.width, requires_grad=True)
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(self.storm_generator.parameters(), lr=0.01)

        for _ in range(10):
            optimizer.zero_grad()
            output = self.storm_generator(input_tensor)
            loss = loss_fn(output, torch.randn(self.batch_size, self.channels, self.height, self.width))
            loss.backward()
            optimizer.step()

            self.assertTrue(all(param.grad is not None for param in self.storm_generator.parameters()), "Gradient calculation failed")

    def test_input_type_handling(self):
        input_tensor = torch.randn(self.batch_size, self.seq_len, self.channels, self.height, self.width, dtype=torch.float64)
        output = self.storm_generator(input_tensor.float())  # Convert input tensor to float32
        self.assertEqual(output.dtype, torch.float32, "StormGenerator output dtype is incorrect")



# Run StormDatasetTest
run_test_suite(TestModelArchitecture)


test_forward_pass (__main__.TestModelArchitecture.test_forward_pass) ... ok
test_gradient_calculation (__main__.TestModelArchitecture.test_gradient_calculation) ... ok
test_initialization (__main__.TestModelArchitecture.test_initialization) ... ok
test_input_type_handling (__main__.TestModelArchitecture.test_input_type_handling) ... ok
test_layer_consistency (__main__.TestModelArchitecture.test_layer_consistency) ... ok
test_matrix_shapes_compatibility (__main__.TestModelArchitecture.test_matrix_shapes_compatibility) ... ok
test_simple_cnn_output_shape (__main__.TestModelArchitecture.test_simple_cnn_output_shape) ... ok
test_simple_cnn_parameter_count (__main__.TestModelArchitecture.test_simple_cnn_parameter_count) ... ok
test_simple_cnn_relu_activations (__main__.TestModelArchitecture.test_simple_cnn_relu_activations) ... ok
test_storm_generator_output_range (__main__.TestModelArchitecture.test_storm_generator_output_range) ... ok

-----------------------------------------------------

Train Function

In [24]:
!pip install pytorch_msssim





In [26]:
from pytorch_msssim import ssim
import torch
import matplotlib.pyplot as plt

def train(model, train_loader, optimizer, device, num_epochs):

    model.train()  # Set the model to training mode
    loss_history = []  # Initialize a list to store the average loss per epoch

    for epoch in range(num_epochs):  # Assuming num_epochs is defined
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            images = batch['images'].to(device)  # Assuming images key in your dataset

            # Split the images into input sequence and target image
            input_images = images[:, :5, :, :, :]  # Use the first five images as input sequence
            target_image = images[:, 5, :, :, :].squeeze(1)  # Use the sixth image as target
            target_image = target_image.unsqueeze(1)  # Ensure target image has the channel dimension

            optimizer.zero_grad()

            # The model expects a 5D tensor, so no change is needed here
            predicted_image = model(input_images)

            # Calculate loss using SSIM
            loss = 1 - ssim(predicted_image, target_image, data_range=1, size_average=True)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        loss_history.append(avg_loss)  # Store the average loss for this epoch

        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        return loss_history

    # After training, plot the training loss
    plt.figure(figsize=(10, 6))
    plt.plot(loss_history, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss Over Time')
    plt.legend()
    plt.show()


In [None]:
#SSIM
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = StormGenerator().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1)
num_epochs = 1
train(model, train_loader, optimizer, device)

Tests for training model

In [27]:
import unittest
from unittest.mock import MagicMock, patch
import torch

class TestTrainingFunction(unittest.TestCase):
    @patch('__main__.ssim', return_value=torch.tensor(0.8, requires_grad=True))  # Ensure SSIM returns a tensor with gradients
    def test_train_initializes_loss_history_and_trains(self, mock_ssim):
        mock_model = MagicMock()
        mock_model.return_value = torch.randn(1, 3, 64, 64, requires_grad=True)  # Mock model output
        
        mock_optimizer = MagicMock()
        mock_device = torch.device('cpu')
        
        # Provide a valid 'storm_id' corresponding to an existing directory within 'root_dir'
        mock_storm_id = '/Users/mk1923/Downloads/Selected_Storms_curated_to_zip/blq'  # Replace with a valid storm ID
        mock_train_loader = StormDataset(root_dir='/Users/mk1923/Downloads/Selected_Storms_curated_to_zip', storm_id=mock_storm_id)
        
        num_epochs = 1
        
        # Run the training function
        with patch('__main__.plt.show'):
            loss_history = train(mock_model, mock_train_loader, mock_optimizer, mock_device, num_epochs)
        
        # Assertions
        self.assertIsInstance(loss_history, list)
        self.assertGreater(len(loss_history), 0)


# Running the test without using unittest.main() to avoid argv issues in Jupyter
if __name__ == '__main__':
    unittest.main(argv=[''], exit=False)


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
    exitcode = _main(fd, parent_sentinel)
                  ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^    self = reduction.pickle.load(from_parent)
^^^^^^^^^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^^^^^
^AttributeError: ^Can't get attribute 'StormDataset' on <module '__main__' (built-in)>^
^^^