# Implementing ase's crystal algorithm

This notebook shows how to use the `Crystal` package. Its main focus, however, is on the components working in the background, to make it easier to grok the algorithm behind the creation of single crystals.

In [None]:
import Pkg
# Activate the project environment in the current directory and install the
# dependencies recorded for it.
Pkg.activate(".")
Pkg.instantiate()

In [None]:
# Pkg.add("LaTeXStrings")

In [None]:
using Plots
using LaTeXStrings
using JSON
using LinearAlgebra
using Test
using Crystal

## Specifying single crystals

In [None]:
crystal_specs = let
    # Every structure below uses spacegroup setting 1 and the same orientation
    # (a along x, ab-plane normal along z), so only the fields that differ
    # between structures are spelled out per entry. The orientation vectors are
    # created afresh for each entry, so no two specs share the same array.
    function make_spec(symbols, basis, nr, cellpar)
        return Dict(
            "symbols" => symbols,
            "basis" => basis,               # scaled coordinates, one 1×3 row per site
            "nr" => nr,                     # spacegroup number
            "setting" => 1,
            "a_direction" => [1.0, 0.0, 0.0],
            "ab_normal" => [0.0, 0.0, 1.0],
            "cellpar" => cellpar,           # [a, b, c, α, β, γ], angles in degrees
        )
    end

    Dict(
        # NaCl structure
        "NaCl" => make_spec(["Na", "Cl"], [[0.0 0.0 0.0], [0.5 0.5 0.5]], 225,
                            [5.64, 5.64, 5.64, 90, 90, 90]),
        # Al fcc structure
        "Al_fcc" => make_spec(["Al"], [[0.0 0.0 0.0]], 225,
                              [4.05, 4.05, 4.05, 90, 90, 90]),
        # Fe bcc structure
        "Fe_bcc" => make_spec(["Fe"], [[0.0 0.0 0.0]], 229,
                              [2.87, 2.87, 2.87, 90, 90, 90]),
        # Mg hcp structure
        "Mg_hcp" => make_spec(["Mg"], [[1/3 2/3 3/4]], 194,
                              [3.21, 3.21, 5.21, 90, 90, 120]),
        # Diamond structure
        "Diamond" => make_spec(["C"], [[0.0 0.0 0.0]], 227,
                               [3.57, 3.57, 3.57, 90, 90, 90]),
        # Rutile (TiO2) structure
        "Rutile" => make_spec(["Ti", "O"], [[0.0 0.0 0.0], [0.3 0.3 0.0]], 136,
                              [4.6, 4.6, 2.95, 90, 90, 90]),
        # CoSb3 skutterudite (the key keeps the notebook's "Skudderudite"
        # spelling, which is also used to look up the reference data)
        "Skudderudite" => make_spec(["Co", "Sb"], [[0.25 0.25 0.25], [0.0 0.335 0.158]],
                                    204, [9.04, 9.04, 9.04, 90, 90, 90]),
    )
end;

## Constructing a single crystal

### High level

Using convenience functions, we can create fcc/bcc unit cells like so:

In [None]:
# bcc unit cell of tungsten; the second argument is presumably the lattice
# constant (NOTE(review): confirm meaning/units against `Crystal.make_bcc_unitcell`).
Crystal.make_bcc_unitcell("W", 3.4)

In [None]:
# Same element and parameter, but built as an fcc unit cell.
Crystal.make_fcc_unitcell("W", 3.4)

If we want to create a unit cell from a basis and spacegroup information, we first collect the specs:

In [None]:
name = "Skudderudite"

# Unpack every field `make_unitcell` needs from the spec table in one step;
# the names on the left are used as globals by the following cells.
nr, setting, basis, symbols, a_direction, ab_normal, cellpar =
    (crystal_specs[name][field] for field in
        ("nr", "setting", "basis", "symbols", "a_direction", "ab_normal", "cellpar"));

Then we pass them to the `make_unitcell` convenience function:

In [None]:
# Build the unit cell from the basis and spacegroup data collected above.
crystal = Crystal.make_unitcell(basis, symbols, nr, setting, cellpar,
                       a_direction=a_direction, ab_normal=ab_normal)
println(crystal)

If you just want to create crystals and don't care about the clockwork in the background, you don't need to look any further.

### Decomposing the high level functions

#### Collecting symmetry operations

In [None]:
# Load the spacegroup table and pick this structure's entry; entries are keyed
# by "<number>: <setting>".
spgs = Crystal.load_spgs()
nr = crystal_specs[name]["nr"]
setting = crystal_specs[name]["setting"]
# Indexing does not copy: `spg` is the very object stored in `spgs`.
spg = spgs["$(nr): $(setting)"]

In [None]:
"""
    parse_spg(spg::Dict{String,Any}) -> Dict{String,Any}

Return a copy of the raw spacegroup record `spg` with its symmetry data
converted to numeric arrays:

- `"subtrans"` and `"translations"`: length-3 `Vector{Float64}`s;
- `"rotations"`: `Matrix{Float64}`s built with `hcat`, i.e. each inner list of
  a raw rotation becomes one *column* of the matrix.

All other entries are carried over unchanged. The input is not modified, and
records that were already parsed are accepted (the conversion is idempotent),
so re-running notebook cells does not corrupt the data.
"""
function parse_spg(spg::Dict{String,Any})::Dict{String,Any}
    parsed = copy(spg)   # shallow copy: leave the caller's (possibly cached) record intact
    parsed["subtrans"] = [_spg_vec3(v) for v in spg["subtrans"]]
    parsed["translations"] = [_spg_vec3(v) for v in spg["translations"]]
    parsed["rotations"] = [_spg_rotation(v) for v in spg["rotations"]]
    return parsed
end

# Convert a 3-element list (raw or already numeric) into a fresh Vector{Float64}.
_spg_vec3(v) = Array{Float64}(reshape(v, 3))

# Raw rotation: a list of row lists, stacked as the columns of the matrix.
_spg_rotation(v) = Array{Float64}(hcat(v...))
# Already a matrix (e.g. from an earlier `parse_spg` call): copy it as Float64.
# Splatting a matrix into `hcat` would flatten it into a 1×9 row instead.
_spg_rotation(v::AbstractMatrix) = Array{Float64}(v)

In [None]:
# Convert the raw arrays of the spacegroup record into Float64 vectors/matrices.
spg = parse_spg(spg)

In [None]:
"""
    get_symops(spg::Dict)

Expand a parsed spacegroup record (see `parse_spg`) into the full list of
symmetry operations, returned as `(rotation, translation)` tuples.

Operations are generated with parity as the outermost loop (`+1`, then `-1`
when `spg["centrosymmetric"]` is true), then the centring sub-translations,
then the listed rotation/translation pairs. This is intended to match the loop
order of ase's `Spacegroup.get_symop` (NOTE(review): verify against the ase
version used for the reference data). The order matters because
`get_equivalent_sites` keeps the first occurrence of every site, so it fixes
the order of the generated positions that are compared index-by-index with ase.

Translations are wrapped into the unit cell with `mod`.

Throws `ArgumentError` if the numbers of rotations and translations differ.
"""
function get_symops(spg::Dict)
    rotations = spg["rotations"]
    translations = spg["translations"]
    length(rotations) == length(translations) || throw(ArgumentError(
        "spacegroup lists $(length(rotations)) rotations but " *
        "$(length(translations)) translations"))
    parities = spg["centrosymmetric"] ? (1, -1) : (1,)
    # Multiple `for` clauses nest left to right: parity is the outermost loop.
    return [(parity * rot, mod.(trans + subtrans, 1))
            for parity in parities
            for subtrans in spg["subtrans"]
            for (rot, trans) in zip(rotations, translations)]
end

In [None]:
symops = get_symops(spg)
# Display the first (up to) four operations. `symops[:4]` would not slice:
# in Julia `:4` is simply the integer 4, so it would return only the 4th operation.
symops[1:min(4, end)]

#### Generating equivalent sites of the basis

In [None]:
"""
    fold(x)

Add one period to a negative value and leave non-negative values unchanged,
e.g. to map the result of `x % 1` into `[0, 1)`. `get_equivalent_sites` wraps
coordinates with `mod` directly and does not need this helper.
"""
function fold(x)
    return x < 0 ? x + 1 : x
end

"""
    get_equivalent_sites(basis, symops; symprec=1e-3)

Apply every symmetry operation to each basis site and keep only the
symmetry-distinct results, in order of first appearance.

# Arguments
- `basis`: scaled (fractional) coordinates, one 1×3 row matrix per site.
- `symops`: `(rotation, translation)` tuples, e.g. from `get_symops`.
- `symprec`: absolute tolerance, in fractional units, for treating two sites as
  the same. The comparison is periodic, so coordinates near 0 and near 1 match.
  The approach mirrors ase's `equivalent_sites` (NOTE(review): confirm the
  default tolerance against the ase version used for the reference data).

# Returns
`(sites, kinds)`: the distinct sites as 3×1 fractional-coordinate matrices
wrapped into the unit cell with `mod`, and for each site the 1-based index of
the basis entry it was generated from.
"""
function get_equivalent_sites(basis::AbstractVector{<:AbstractMatrix}, symops;
                              symprec::Real=1e-3)
    kinds, sites = [], []

    for (kind, pos) in enumerate(basis)
        for (rot, trans) in symops
            # `pos` is a 1×3 row; transposing `pos * rot` gives a 3×1 column.
            site = mod.(transpose(pos * rot) + trans, 1)
            if !any(other -> _is_same_site(other, site, symprec), sites)
                push!(sites, site)
                push!(kinds, kind)
            end
        end
    end
    return sites, kinds
end

# True when two fractional sites coincide within `symprec`, modulo lattice
# translations (a difference of ≈±1 along an axis also counts as equal).
function _is_same_site(a, b, symprec)
    return all(d -> abs(d) < symprec || abs(abs(d) - 1) < symprec, a - b)
end

Using the basis and the symmetry operations, we can generate the equivalent sites:

In [None]:
# Expand the basis with all symmetry operations into the symmetry-distinct
# fractional sites; `kinds` holds, per site, the index of its basis entry.
basis = crystal_specs[name]["basis"]
sites, kinds = get_equivalent_sites(basis, symops)

#### Computing the cell box vectors

In [None]:
"""
    make_unit_vec(x)

Return `x` scaled to unit Euclidean norm.
"""
function make_unit_vec(x::AbstractVector)
    return x / norm(x)
end

"""
    get_coords(a_direction, ab_normal)

Build the Cartesian frame in which the cell is oriented.

`ab_normal` fixes the z axis. `a_direction` is projected onto the plane
perpendicular to it, so it need not be exactly orthogonal to `ab_normal`, and
then normalised to give the x axis. y completes the frame as `z × x`. This
mirrors how ase's `cellpar_to_cell` treats these two arguments
(NOTE(review): confirm against the ase version used for the reference data).

Returns a 3×3 matrix whose columns are the unit vectors x, y, z.

Throws `ArgumentError` if `ab_normal` is zero or `a_direction` is (numerically)
parallel to it.
"""
function get_coords(a_direction::AbstractVector{<:Real}, ab_normal::AbstractVector{<:Real})
    iszero(norm(ab_normal)) && throw(ArgumentError("ab_normal must be non-zero"))
    z = make_unit_vec(ab_normal)

    # Remove the component along z, using the *unit* normal for the projection.
    in_plane = a_direction - dot(a_direction, z) * z
    if norm(in_plane) <= sqrt(eps(Float64)) * norm(a_direction)
        throw(ArgumentError("a_direction must not be parallel to ab_normal"))
    end
    x = make_unit_vec(in_plane)

    return hcat(x, cross(z, x), z)
end

In [None]:
# Orientation frame (columns x, y, z) from the a-direction and the ab-plane normal.
a_direction = crystal_specs[name]["a_direction"]
ab_normal = crystal_specs[name]["ab_normal"]
xyz = get_coords(a_direction, ab_normal)

In [None]:
"""
    deg2rad(x)

Convert the angle `x` from degrees to radians, computed as `x * π / 180`.

NOTE: this shadows `Base.deg2rad` in this notebook. It is kept as a separate
definition because the `x * π / 180` form follows the formula ase uses for
the cell angles, while `Base.deg2rad` may round differently in the last bit.
"""
function deg2rad(x::T) where T <: Real
    return x * π / 180.
end

"""
    get_cos(x)

Cosine of the angle `x` given in degrees. Angles approximately equal to ±90°
return exactly zero, so right angles give exactly orthogonal cell vectors.
The result is always a float.
"""
function get_cos(x::T) where T <: Real
    c = cos(deg2rad(x))
    return isapprox(abs(x), 90) ? zero(c) : c
end

"""
    get_cell_vectors(cellpar)

Build the cell vectors from the lattice parameters `cellpar = [a, b, c, α, β, γ]`
(lengths, then angles in degrees).

`a` lies along x and `b` in the xy-plane. The third vector is fixed by the
remaining angles. Returns a 3×3 matrix with one cell vector per column.

Throws `ArgumentError` if the angles cannot describe a real cell (the z component
of the third vector would be imaginary).
"""
function get_cell_vectors(cellpar::AbstractVector{<:Real})
    a, b, c, α, β, γ = cellpar
    cos_α = get_cos(α)
    cos_β = get_cos(β)
    cos_γ = get_cos(γ)
    # ±90° is snapped to exactly ±1 for the same reason as in `get_cos`.
    sin_γ = isapprox(abs(γ), 90) ? float(sign(γ)) : sin(deg2rad(γ))

    cy = (cos_α - cos_β * cos_γ) / sin_γ
    cz_sq = 1 - cos_β * cos_β - cy * cy
    cz_sq >= 0 || throw(ArgumentError(
        "cell angles α=$(α), β=$(β), γ=$(γ) do not describe a valid cell"))

    return hcat([a; 0; 0], b * [cos_γ; sin_γ; 0], c * [cos_β; cy; sqrt(cz_sq)])
end

In [None]:
# Cell vectors from the lattice parameters, one vector per column.
abc = get_cell_vectors(crystal_specs[name]["cellpar"])

In [None]:
# Rotate the cell vectors into the requested orientation. `abc` holds one cell
# vector per column and `xyz` the frame axes x, y, z as columns, so each vector
# v = (vx, vy, vz) becomes vx*x + vy*y + vz*z, i.e. `xyz * v`. (`abc * xyz` would
# mix the components of different cell vectors.) Fractional positions then map
# to Cartesian ones as `cell * frac`.
cell = xyz * abc

#### Storing everything within a `Crystal.Cell` struct

In [None]:
# Element lookup tables from the Crystal package (as named here: symbols,
# atomic numbers and masses; `masses` is indexed by element symbol below).
chemical_symbols, atomic_numbers, masses = Crystal.get_chemical_info()

In [None]:
# storing crystal properties in a `Crystal.Cell` instance
el2atom_map = Dict(el => Crystal.Atom(name=el, mass=masses[el]) for el in keys(masses))

cc = Crystal.CartesianCoords(Float64)
# Build the box from `cell`, the same matrix that maps the fractional sites to
# the Cartesian positions below, so box and positions always describe the same
# cell. When `xyz` is the identity (the orientation used in `crystal_specs`)
# this equals `abc`.
box = Crystal.PrimitiveVectors(cc, A₁=cell[:,1], A₂=cell[:,2], A₃=cell[:,3])

_spg = Crystal.Spacegroup(nr, setting, kinds, sites)

crystal = Crystal.Cell(
    [el2atom_map[symbols[v]] for v in kinds],           # one Atom per generated site
    [cell * v for v in sites],                          # fractional → Cartesian
    box,
    [norm(abc[:,1]), norm(abc[:,2]), norm(abc[:,3])],   # cell lengths a, b, c
    _spg
)

### Plotting a `Crystal.Cell` struct

In [None]:
"""
    plot_crystal(cell; default_color="blue", element_color_map=Dict("Fe" => "blue"),
                 default_size=50, element_size_map=Dict{String,Any}())

Return an animated 3-D scatter plot (`@gif`, 100 frames) of the atoms in `cell`.
The first camera angle varies as `10 * (1 + cos(i))` over the animation.

Marker colour and size are looked up by element name in `element_color_map` and
`element_size_map`. Elements not listed there use `default_color` /
`default_size`. The maps passed in are only read, never modified.
"""
function plot_crystal(cell::Crystal.Cell;
        default_color::String="blue",
        element_color_map::Dict=Dict{String,String}("Fe" => "blue"),
        default_size::T=50,
        element_size_map::Dict=Dict{String,Any}()
    ) where T <: Real

    atoms, coords = cell.atoms, cell.positions
    names = [atom.name for atom in atoms]

    # Look up with a fallback instead of inserting the defaults into the maps,
    # so the caller's dictionaries are not changed as a side effect of plotting
    # and later calls with other defaults are not affected.
    colors = [get(element_color_map, n, default_color) for n in names]
    sizes = [get(element_size_map, n, default_size) for n in names]
    elements = sort!(unique(names))   # deterministic order for the title

    x = [v[1] for v in coords]
    y = [v[2] for v in coords]
    z = [v[3] for v in coords]
    return @gif for i in range(0, stop=2π, length=100)
        scatter(x, y, z, camera=(10*(1+cos(i)),5),
            markersize=sizes, legend=false,
            color=colors, aspect_ratio=:equal,
            xlabel=L"x", ylabel=L"y", zlabel=L"z",
            title=string(length(atoms), " atoms of: ", join(elements, ","))
        )
    end
end

In [None]:
# Per-element marker colours and sizes passed to `plot_crystal`; elements not
# listed here are drawn with the plot's default colour and size.
element_color_map = Dict(
    "Na" => "purple",
    "Cl" => "green",
    "Co" => "pink",
    "Sb" => "purple",
)
element_size_map = Dict(
    "Na" => 20,
    "Cl" => 10,
    "Co" => 15,
    "Sb" => 15,
)

In [None]:
# Animated plot of the unit cell built above; elements missing from the maps
# fall back to the default colour and to marker size 5.
plot_crystal(crystal, 
    element_color_map=element_color_map,
    element_size_map=element_size_map,
    default_size=5)

## Sanity checking generated positions and cell boxes

Using `ase` as a reference, we can directly check whether our positions and cell boxes have the expected values. For this, we load pre-computed values from disk.

In [None]:
# Reference structures pre-computed with ase, keyed by the same names as `crystal_specs`.
json_ase_crystals = Crystal.load_refs()

### Checking a single crystal

In [None]:
# Key as spelled in `crystal_specs` and the reference data (the mineral is
# usually spelled "skutterudite").
name = "Skudderudite"

In [None]:
# Raw reference entry for the selected structure.
json_ase_crystal = json_ase_crystals[name]

In [None]:
# Parsed reference; the checks below read its "positions" and "cell" entries.
ase_crystal = Crystal.parse_json_crystal(json_ase_crystal)

In [None]:
"""
    positions_match_ase(crystal, ase_crystal)

Return `true` when `crystal` has exactly as many positions as the parsed ase
reference and each position is approximately equal (`≈`) to the reference
position at the same index. Site order must therefore agree with the
reference, not just the set of sites.
"""
function positions_match_ase(crystal::Crystal.Cell, ase_crystal::Dict)
    reference = ase_crystal["positions"]
    # `zip` stops at the shorter input, so without this check a crystal with
    # missing or extra atoms could still pass on the common prefix.
    length(reference) == length(crystal.positions) || return false
    return all(p0 ≈ p1 for (p0, p1) in zip(reference, crystal.positions))
end

In [None]:
# The hand-assembled crystal must reproduce the reference positions, in order.
@assert positions_match_ase(crystal, ase_crystal)

In [None]:
"""
    cell_matches_ase(crystal, ase_crystal)

Check whether the box matrix of `crystal` is approximately equal to the cell
stored under `"cell"` in the parsed ase reference.
"""
function cell_matches_ase(crystal::Crystal.Cell, ase_crystal::Dict)
    box_matrix = crystal.box.M
    reference_cell = ase_crystal["cell"]
    return isapprox(box_matrix, reference_cell)
end

In [None]:
# ...and its box must match the reference cell matrix.
@assert cell_matches_ase(crystal, ase_crystal)

### Looping all crystals

All right, this looks good so far. Let's check all crystals.

In [None]:
crystals = Dict{String,Any}()

# The per-structure testsets are nested inside one outer testset. A failing
# top-level testset throws as soon as it finishes. Inside a bare `for` loop
# that would abort the loop and leave `crystals` without the remaining
# structures, which later cells index into. Nested testsets only record their
# results; the outer one reports, and throws if anything failed, once at the end.
@testset "ase reference crystals" begin
    for name in keys(crystal_specs)
        @testset "$(name)" begin
            spec = crystal_specs[name]
            crystal = Crystal.make_unitcell(
                spec["basis"], spec["symbols"], spec["nr"], spec["setting"],
                spec["cellpar"],
                a_direction=spec["a_direction"], ab_normal=spec["ab_normal"])
            ase_crystal = Crystal.parse_json_crystal(json_ase_crystals[name])

            @testset "positions match" begin
                @test positions_match_ase(crystal, ase_crystal)
            end
            @testset "cell match" begin
                @test cell_matches_ase(crystal, ase_crystal)
            end

            crystals[name] = crystal
        end
    end
end

## Generating a supercell

By cloning the atoms of the unit cell and shifting them along the cell box vectors by integer multiples up to `nx`, `ny` and `nz`, we create a *supercell*:

In [None]:
name = "Al_fcc"
# Replicate the Al unit cell 3 times along each of the three cell vectors.
supercell = Crystal.make_supercell(crystals[name], nx=3, ny=3, nz=3);

In [None]:
# Animated plot of the supercell. Al is not in the element maps, so it uses the
# default colour and marker size 5.
plot_crystal(supercell, 
    element_color_map=element_color_map,
    element_size_map=element_size_map,
    default_size=5)

## Storing crystal as json readable by `ase.Atoms.fromdict`

Since one may want to reuse the crystals created with this Julia package from Python with ase, let's export them.

In [None]:
# One `Crystal.cell2dict` entry per structure name, all written to a single JSON
# file. NOTE(review): that the layout is readable by `ase.Atoms.fromdict`
# depends on `Crystal.cell2dict`, which is not shown here.
ds = Dict(structure => Crystal.cell2dict(unitcell) for (structure, unitcell) in crystals)

open("julia-atoms.json", "w") do io
    JSON.print(io, ds)
end