In [None]:
#default_exp utils

In [None]:
#exporti
import os
import json
import datetime
from collections import defaultdict
from typing import Union
import torch

In [None]:
#hide
from nbdev.showdoc import show_doc

# Utils

In [None]:
#export
def get_current_datetime_as_string():
    """
    Returns the current local date and time as a string of the form 'YYYY-MM-DD--HH:MM:SS'.

    NOTE: the returned string contains colons, which are not allowed in file or directory names on Windows.
    """
    return datetime.datetime.now().strftime('%Y-%m-%d--%H:%M:%S')


def create_dir(
    name:str, # The name of the directory that should be created.
    path:str=".", # The path where the directory should be created.
    prepend_date:bool=False # Whether to prepend the directory name with the date and time of its creation. Ensures unique directory names.
):
    """
    Creates a new directory `path/name`, optionally prepended with the current datetime, and returns its path as a string.
    Missing intermediate directories are created as well. If the directory already exists, nothing happens.
    Raises `FileExistsError` if `path/name` already exists but is not a directory.

    NOTE: with `prepend_date=True` the name contains colons (from the datetime format), which Windows does not allow in
    directory names.
    """
    if prepend_date:
        cdt = get_current_datetime_as_string()
        name = f"{cdt}_{name}"

    dir_path = f"{path}/{name}"

    # `exist_ok=True` avoids the race of a separate existence check (another process creating the directory in between)
    # and still raises if `dir_path` exists as a regular file instead of silently returning a non-directory path.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def save_dict_as_txt(
    my_dict:dict, # The dictionary that should be saved; it has to be JSON-serializable.
    dir_path:str, # The (already existing) directory in which the txt file should be written.
    file_name:str # The name of the txt file that should be created; ".txt" is appended if it is missing.
):
    """
    Saves a python dictionary as pretty-printed JSON (indent of 2) into the txt file `dir_path/file_name`.
    An existing file with the same name is overwritten.
    """
    target = f"{dir_path}/{file_name}"
    if not target.endswith(".txt"):
        target = target + ".txt"
    with open(target, 'w') as handle:
        json.dump(my_dict, handle, indent=2)

In [None]:
#export
def cast_to_problem(
    problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
):
    """
    Accepts a problem or a solution object as input and returns a problem. If the input is a problem, then the problem is simply returned without modification.
    If it is a solution, then `solution.problem` is returned.

    Any object that has a `problem` attribute is treated as a solution. Lists are rejected; use `cast_to_problems` for them.
    """
    assert type(problem_or_solution) != list, type(problem_or_solution)
    # Only a missing attribute means "this already is a problem". Any other error raised while accessing `.problem`
    # is a genuine failure and must propagate instead of silently returning an object of the wrong type.
    try:
        return problem_or_solution.problem
    except AttributeError:
        return problem_or_solution


def cast_to_solution(
    problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
):
    """
    Accepts a problem or a solution object as input and returns a solution. If the input is a problem, then `problem.trivial_solution` is returned.
    If the input is a solution, then it is simply returned without modification.

    Any object that has a `trivial_solution` attribute is treated as a problem. Lists are rejected; use `cast_to_solutions` for them.
    """
    assert type(problem_or_solution) != list, type(problem_or_solution)
    # Only a missing attribute means "this already is a solution". Any other error raised while accessing
    # `.trivial_solution` is a genuine failure and must propagate instead of silently returning the unconverted input.
    try:
        return problem_or_solution.trivial_solution
    except AttributeError:
        return problem_or_solution


def cast_to_problems(
    problems_or_solutions:list # A list containing problem and solution objects.
):
    """
    Converts a list of problem and solution objects into a list of problems, element by element via `cast_to_problem`:
    problems are kept as they are, solutions are replaced by `solution.problem`. The order of the input is preserved.
    """
    return list(map(cast_to_problem, problems_or_solutions))


def cast_to_solutions(
    problems_or_solutions:list # A list containing problem and solution objects.
):
    """
    Converts a list of problem and solution objects into a list of solutions, element by element via `cast_to_solution`:
    solutions are kept as they are, problems are replaced by `problem.trivial_solution`. The order of the input is preserved.
    """
    return list(map(cast_to_solution, problems_or_solutions))

In [None]:
#export
def _transpose_batch(batch):
    """
    Collate function used by `get_dataloader`: turns a list of per-sample tuples into a list of per-field tuples,
    e.g. `[(a1, b1), (a2, b2)]` -> `[(a1, a2), (b1, b2)]`. Nothing is stacked into tensors.

    It is defined at module level (instead of as a lambda) so that it can be pickled. `torch.utils.data.DataLoader`
    needs this when `num_workers > 0` and worker processes are started with the "spawn" method (the default on
    Windows and macOS).
    """
    return list(zip(*batch))


def get_dataloader(
    dataset:"dl4to.datasets.TopoDataset", # The dataset for which the dataloader should be created.
    batch_size:int=1, # The batch size for the dataloader.
    shuffle:bool=True, # Whether the dataloader should shuffle the samples.
    num_workers:int=0 # The number of subprocesses used for data loading (0 means the data is loaded in the main process).
):
    """
    Returns a `torch.utils.data.DataLoader` object for `dataset`.
    Each batch is a list holding one tuple per field of the dataset's samples, instead of stacked tensors.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=_transpose_batch,
        shuffle=shuffle,
        num_workers=num_workers
    )

In [None]:
#export
def get_σ_vm(
    σ:torch.Tensor, # The 9-channel stress tensor, either 4D with the channels in dim 0 or 5D (batched) with the channels in dim 1.
    ε:float=1e-9 # A small value added under the square root that ensures numerically stable results.
):
    """
    Calculates the von Mises stresses from the 9-channel stress tensor `σ` and returns them in a 1-channel `torch.Tensor` object.

    `σ` is either 4-dimensional with its 9 channels in dimension 0, or 5-dimensional with a leading batch dimension and its
    9 channels in dimension 1. Channels 0, 4 and 8 are used as the normal stresses and channels 1, 2 and 5 as the shear
    stresses. This matches a symmetric 3x3 stress matrix flattened row by row; the lower-triangle channels 3, 6 and 7
    are ignored. The output has the same rank as `σ`, with a single channel in place of the 9.
    Raises a `ValueError` if `σ` is neither 4- nor 5-dimensional.
    """
    if σ.dim() == 4:
        channel_dim = 0
    elif σ.dim() == 5:
        channel_dim = 1
    else:
        raise ValueError(
            f"σ must either be 4-dimensional (9, ...) or 5-dimensional (batch, 9, ...), but has shape {tuple(σ.shape)}."
        )
    assert σ.shape[channel_dim] == 9, σ.shape

    # Row-major entries of the 3x3 stress matrix; the symmetric lower-triangle duplicates are not needed.
    σ_xx, σ_xy, σ_xz, _, σ_yy, σ_yz, _, _, σ_zz = σ.unbind(channel_dim)
    σ_vm = (.5 * ((σ_xx - σ_yy) ** 2
                + (σ_yy - σ_zz) ** 2
                + (σ_zz - σ_xx) ** 2
                + 6 * (σ_xy ** 2 + σ_xz ** 2 + σ_yz ** 2)
                  ) + ε
           ) ** .5
    return σ_vm.unsqueeze(channel_dim)

In [None]:
#hide
from dl4to.problem import Problem
from dl4to.solution import Solution
from dl4to.datasets import BasicDataset

In [None]:
%%time
#hide

def test_that__cast_to_solutions__returns_only_solutions():
    # A mix of problems and (trivial) solutions; both cast helpers must normalise it to a single type.
    mixed = [
        BasicDataset().ledge(),
        BasicDataset().cantilever().get_trivial_solution(),
        BasicDataset().fork(),
        BasicDataset().wheel().get_trivial_solution(),
    ]

    as_problems = cast_to_problems(mixed)
    for item in as_problems:
        assert type(item) == Problem, type(item)

    as_solutions = cast_to_solutions(mixed)
    for item in as_solutions:
        assert type(item) == Solution, type(item)


test_that__cast_to_solutions__returns_only_solutions()

CPU times: user 30.3 ms, sys: 3.32 ms, total: 33.6 ms
Wall time: 49.7 ms


In [None]:
%%time
#hide

def test_that_diagonal_matrix_σ_results_in_zero_vm_stresses():
    """
    A hydrostatic stress state (equal normal stresses, no shear) must yield zero von Mises stress.
    `get_σ_vm` only accepts 4D (channels-first) or 5D (batched) input, so the flattened 3x3 matrix is reshaped to a single voxel.
    `ε=0.` is passed because the default `ε` would make the result sqrt(ε) instead of exactly 0.
    """
    σ = torch.eye(3).reshape(9, 1, 1, 1)
    σ_vm = get_σ_vm(σ, ε=0.)
    assert σ_vm.shape == (1, 1, 1, 1), σ_vm.shape
    assert torch.all(σ_vm == 0.), σ_vm

    # Same check for the batched (5D) input path with a scaled hydrostatic state.
    σ_batched = (3 * torch.eye(3)).reshape(1, 9, 1, 1, 1).repeat(2, 1, 1, 1, 1)
    σ_vm_batched = get_σ_vm(σ_batched, ε=0.)
    assert σ_vm_batched.shape == (2, 1, 1, 1, 1), σ_vm_batched.shape
    assert torch.all(σ_vm_batched == 0.), σ_vm_batched

test_that_diagonal_matrix_σ_results_in_zero_vm_stresses()

CPU times: user 2.7 ms, sys: 740 µs, total: 3.44 ms
Wall time: 2.57 ms
