# Operator overloading

| Category                                 | Trait                    | Operator                             |
| ---------------------------------------- | ------------------------ | ------------------------------------ |
| Unary operators                          | `std::ops::Neg`          | `-x`                                 |
|                                          | `std::ops::Not`          | `!x`                                 |
| Arithmetic operators                     | `std::ops::Add`          | `x + y`                              |
|                                          | `std::ops::Sub`          | `x - y`                              |
|                                          | `std::ops::Mul`          | `x * y`                              |
|                                          | `std::ops::Div`          | `x / y`                              |
|                                          | `std::ops::Rem`          | `x % y`                              |
| Bitwise operators                        | `std::ops::BitAnd`       | `x & y`                              |
|                                          | `std::ops::BitOr`        | `x \| y`                             |
|                                          | `std::ops::BitXor`       | `x ^ y`                              |
|                                          | `std::ops::Shl`          | `x << y`                             |
|                                          | `std::ops::Shr`          | `x >> y`                             |
| Compound assignment arithmetic operators | `std::ops::AddAssign`    | `x += y`                             |
|                                          | `std::ops::SubAssign`    | `x -= y`                             |
|                                          | `std::ops::MulAssign`    | `x *= y`                             |
|                                          | `std::ops::DivAssign`    | `x /= y`                             |
|                                          | `std::ops::RemAssign`    | `x %= y`                             |
| Compound assignment bitwise operators    | `std::ops::BitAndAssign` | `x &= y`                             |
|                                          | `std::ops::BitOrAssign`  | `x \|= y`                            |
|                                          | `std::ops::BitXorAssign` | `x ^= y`                             |
|                                          | `std::ops::ShlAssign`    | `x <<= y`                            |
|                                          | `std::ops::ShrAssign`    | `x >>= y`                            |
| Comparison                               | `std::cmp::PartialEq`    | `x == y`, `x != y`                   |
|                                          | `std::cmp::PartialOrd`   | `x < y`, `x <= y`, `x > y`, `x >= y` |
| Indexing                                 | `std::ops::Index`        | `x[y]`, `&x[y]`                      |
|                                          | `std::ops::IndexMut`     | `x[y] = z`, `&mut x[y]`              |
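
Each operator in the table is syntactic sugar for a call to the corresponding trait method, so for `i32` values `x + y` is the same as `std::ops::Add::add(x, y)`. A minimal sketch using only built-in types:

```rust
use std::ops::{Add, Neg, Not};

fn main() {
    let x: i32 = 7;
    let y: i32 = 5;

    // Binary arithmetic operators take both operands by value.
    assert_eq!(x + y, Add::add(x, y));

    // Unary operators map to single-argument trait methods.
    assert_eq!(-x, Neg::neg(x));
    assert_eq!(!true, Not::not(true));

    // Comparison operators take their operands by reference.
    assert_eq!(x == y, PartialEq::eq(&x, &y));
    assert_eq!(x < y, PartialOrd::lt(&x, &y));

    println!("{} + {} = {}", x, y, Add::add(x, y));
}
```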

In [2]:
#[derive(Debug, Clone, Copy)]
struct Complex<T> {
    re: T,
    im: T,
}

In [5]:
let z: Complex<f64> = Complex { re: 1.0, im: 2.0 };

In [5]:
z + z

Error: an implementation of `Add` might be missing for `Complex<f64>`

## Implementing the `Add` trait

In [6]:
mod explain {
    trait Add<RHS=Self> {
        type Output;
        fn add(self, rhs: RHS) -> Self::Output;
    }
}

In [6]:
use std::ops::Add;

impl<T> Add for Complex<T> 
where T: Add<Output = T> {
    type Output = Self;   
     
    fn add(self, rhs: Complex<T>) -> Self::Output {
        Complex {
            re: self.re + rhs.re,
            im: self.im + rhs.im,
        }
    }
}

In [7]:
z + z

Complex { re: 2.0, im: 4.0 }
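
Because `Add` is generic over its right-hand operand (`RHS=Self` is only a default), one type can support several `+` overloads. The sketch below assumes, purely for illustration, that adding an `f64` to a `Complex<f64>` should shift only the real part; this impl is not part of the notebook above:

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex<T> {
    re: T,
    im: T,
}

// Complex + Complex, same as in the notebook.
impl<T: Add<Output = T>> Add for Complex<T> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Complex { re: self.re + rhs.re, im: self.im + rhs.im }
    }
}

// Complex<f64> + f64: a real scalar only changes the real part.
impl Add<f64> for Complex<f64> {
    type Output = Complex<f64>;
    fn add(self, rhs: f64) -> Complex<f64> {
        Complex { re: self.re + rhs, im: self.im }
    }
}

fn main() {
    let z = Complex { re: 1.0, im: 2.0 };
    println!("{:?}", z + 0.5_f64); // Complex { re: 1.5, im: 2.0 }
    println!("{:?}", z + z);       // Complex { re: 2.0, im: 4.0 }
}
```

The two impls do not conflict because their `RHS` types (`Complex<T>` and `f64`) can never be the same type.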

## Implementing the `Neg` trait

In [8]:
mod explain {
    trait Neg {
        type Output;
        fn neg(self) -> Self::Output;
    }
}

In [9]:
-z

Error: an implementation of `std::ops::Neg` might be missing for `Complex<f64>`

In [10]:
use std::ops::Neg;

impl<T> Neg for Complex<T>
where T: Neg<Output = T> {
    type Output = Self;
    
    fn neg(self) -> Self::Output {
        Complex {
            re: -self.re,
            im: -self.im,
        }
    }
}

In [11]:
-z

Complex { re: -1.0, im: -2.0 }
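
The remaining arithmetic traits from the table follow the same pattern. As a sketch (not part of the original notebook), complex multiplication uses `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`, so `T` needs `Add`, `Sub` and `Mul`, plus `Copy` because each component is used twice:

```rust
use std::ops::{Add, Mul, Sub};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex<T> {
    re: T,
    im: T,
}

impl<T> Mul for Complex<T>
where
    T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Copy,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        Complex {
            re: self.re * rhs.re - self.im * rhs.im,
            im: self.re * rhs.im + self.im * rhs.re,
        }
    }
}

fn main() {
    let i = Complex { re: 0, im: 1 };
    println!("{:?}", i * i); // Complex { re: -1, im: 0 }

    let z = Complex { re: 1.0, im: 2.0 };
    println!("{:?}", z * z); // Complex { re: -3.0, im: 4.0 }
}
```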

## Compound Assignment Operators: `+=`, `-=`, `*=`, `/=`, `%=`

In [13]:
mod explain {
    trait AddAssign<RHS=Self> {
        fn add_assign(&mut self, rhs: RHS);
    }
}

In [12]:
use std::ops::AddAssign;

impl<T> AddAssign for Complex<T>
where T: AddAssign<T> {
    fn add_assign(&mut self, rhs: Complex<T>) {
        self.re += rhs.re;
        self.im += rhs.im;
    }
}

In [15]:
let mut c = Complex { re: 1.0, im: 2.0 };

c += Complex { re: 3.0, im: 4.0 };

c

Complex { re: 4.0, im: 6.0 }
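
For `Copy` types that already implement `Add`, a common shortcut is to implement `AddAssign` in terms of `+`. This is a sketch under that assumption: it adds a `T: Copy` bound, which the notebook's `T: AddAssign<T>` version does not need.

```rust
use std::ops::{Add, AddAssign};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex<T> {
    re: T,
    im: T,
}

impl<T: Add<Output = T>> Add for Complex<T> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Complex { re: self.re + rhs.re, im: self.im + rhs.im }
    }
}

// `x += y` implemented as `x = x + y`. Reading `*self` out from behind
// `&mut self` requires `Complex<T>: Copy`, hence the `T: Copy` bound.
impl<T: Add<Output = T> + Copy> AddAssign for Complex<T> {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

fn main() {
    let mut c = Complex { re: 1.0, im: 2.0 };
    c += Complex { re: 3.0, im: 4.0 };
    println!("{:?}", c); // Complex { re: 4.0, im: 6.0 }
}
```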

## Equality and Ordering

### PartialEq

* To use the `==` and `!=` operators, implement the `PartialEq` trait:

In [None]:
mod explain {
    trait PartialEq<RHS=Self> 
    where RHS: ?Sized
    {
        fn eq(&self, other: &RHS) -> bool;        
        fn ne(&self, other: &RHS) -> bool { !self.eq(other) } // default implementation
    }
}

In [6]:
use std::cmp::PartialEq;

impl<T: PartialEq> PartialEq for Complex<T> {
    fn eq(&self, other: &Complex<T>) -> bool {
        // Equivalent to: self.re == other.re && self.im == other.im
        (&self.re, &self.im) == (&other.re, &other.im)
    }
}

* Now we can check for equality using the `==` and `!=` operators:

In [4]:
z == z

true

In [19]:
z != c

true
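
Like `Add`, `PartialEq` has an `RHS` parameter, so a type can be compared with values of a different type. The sketch below assumes, as an illustrative choice not taken from the notebook, that a `Complex<f64>` equals an `f64` when its imaginary part is zero and the real parts match:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct Complex<T> {
    re: T,
    im: T,
}

// Complex<f64> == f64: equal when the value lies on the real axis.
impl PartialEq<f64> for Complex<f64> {
    fn eq(&self, other: &f64) -> bool {
        self.im == 0.0 && self.re == *other
    }
}

fn main() {
    let real = Complex { re: 3.0, im: 0.0 };
    let z = Complex { re: 3.0, im: 1.0 };
    println!("{}", real == 3.0_f64); // true
    println!("{}", z == 3.0_f64);    // false
    println!("{}", z != 3.0_f64);    // true, via the default `ne`
}
```

This impl only covers `Complex<f64> == f64`; writing `3.0 == real` would need a separate `impl PartialEq<Complex<f64>> for f64`.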

#### PartialEq & f64

* The mathematical definition of an equivalence relation, of which equality is one instance, imposes three requirements. For any values `x`, `y` and `z`:
  * Symmetry: if `x == y` is true, then `y == x` must be true as well. In other words, swapping the two sides of an equality comparison doesn’t affect the result.
  * Transitivity: if `x == y` and `y == z`, then it must be the case that `x == z`. Given any chain of values, each equal to the next, each value in the chain is directly equal to every other. Equality is contagious.
  * Reflexivity: it must always be true that `x == x`.

The types `f32` and `f64` are IEEE standard floating-point values. According to that standard, expressions like `0.0/0.0` and others with no appropriate value must produce special not-a-number values, usually referred to as `NaN` values. The standard further requires that a `NaN` value be treated as unequal to every other value, including itself.

In [21]:
assert!(f64::is_nan(0.0 / 0.0));

assert_eq!(0.0/0.0 == 0.0/0.0, false);
assert_eq!(0.0/0.0 != 0.0/0.0, true);

In [26]:
println!("{:?} == {:?} yields {:?}", 0.0/0.0, f64::NAN, 0.0/0.0 == f64::NAN);

NaN == NaN yields false


While Rust’s `==` operator meets the first two requirements for equivalence relations, it clearly doesn’t meet the third when used on IEEE floating-point values. This is called a partial equivalence relation, so Rust uses the name `PartialEq` for the `==` operator’s built-in trait.

### Eq - Full equivalence relation

If you’d prefer your generic code to require a full equivalence relation, use the `std::cmp::Eq` trait as a bound instead: if a type implements `Eq`, then `x == x` must be `true` for every value `x` of that type.

In practice, almost every type that implements `PartialEq` should implement `Eq` as well. Among the standard library's primitive types, the floating-point types `f32` and `f64` are the ones that are `PartialEq` but not `Eq`; generic types such as `Vec<f64>` or `Option<f32>` inherit that limitation from their element type.
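
Hash-based collections are a typical place where a full equivalence relation is required: `HashSet<T>` needs `T: Eq + Hash`. A sketch of a generic helper (the name `count_distinct` is hypothetical) that accepts integers and strings but not `f64`:

```rust
use std::collections::HashSet;
use std::hash::Hash;

// Counts distinct values; the `Eq + Hash` bound comes from `HashSet`.
fn count_distinct<T: Eq + Hash>(items: &[T]) -> usize {
    items.iter().collect::<HashSet<&T>>().len()
}

fn main() {
    println!("{}", count_distinct(&[1, 2, 2, 3, 3, 3])); // 3
    println!("{}", count_distinct(&["a", "b", "a"]));    // 2

    // Does not compile: `f64` implements `PartialEq` but neither `Eq` nor `Hash`.
    // count_distinct(&[1.0, f64::NAN]);
}
```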

The standard library defines `Eq` as an extension of `PartialEq`, adding no new methods:

In [None]:
mod explain {
    trait Eq : PartialEq<Self> {}
}

Implementing `Eq` for `Complex<T>` is straightforward:

In [32]:
impl<T: Eq> Eq for Complex<T> {}

Both `PartialEq` and `Eq` can also be derived for a generic type:

In [34]:
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Value<T> {
    value: T
}

Derived implementations on a generic type may depend on the type parameters. With the derive attribute, `Value<i32>` would implement `Eq`, because `i32` does, but `Value<f32>` would implement only `PartialEq`, since `f32` doesn’t implement `Eq`.
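
One way to observe this is a generic function with an `Eq` bound (the helper `require_eq` is hypothetical): it accepts `Value<i32>`, while a call with `Value<f32>` is rejected at compile time even though `==` itself still works on it.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Value<T> {
    value: T,
}

// Only callable with types that have a full equivalence relation.
fn require_eq<T: Eq>(a: &T, b: &T) -> bool {
    a == b
}

fn main() {
    let a = Value { value: 1_i32 };
    let b = Value { value: 1_i32 };
    println!("{}", require_eq(&a, &b)); // true

    // `Value<f32>` is only `PartialEq`, so `==` compiles...
    let x = Value { value: f32::NAN };
    println!("{}", x == x); // false, because NaN != NaN
    // ...but this call would not compile, since `f32: Eq` does not hold:
    // require_eq(&x, &x);
}
```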

### PartialOrd

Any ordering comparison (`<`, `>`, `<=`, `>=`) between `f32` or `f64` values that involves `NaN` returns `false`:

In [38]:
assert_eq!(0.0/0.0 < 0.0/0.0, false);
assert_eq!(0.0/0.0 > 0.0/0.0, false);
assert_eq!(0.0/0.0 <= 0.0/0.0, false);
assert_eq!(0.0/0.0 >= 0.0/0.0, false);

In [39]:
mod explain {
    enum Ordering {
        Less,
        Equal,
        Greater,
    }

    trait PartialOrd<RHS=Self>: PartialEq<RHS> 
    where 
        RHS: ?Sized 
    {
        fn partial_cmp(&self, other: &RHS) -> Option<Ordering>;
        
        fn lt(&self, other: &RHS) -> bool { 
            match self.partial_cmp(other) {
                Some(Ordering::Less) => true,
                _ => false,
            }
        }
        
        fn le(&self, other: &RHS) -> bool { 
            match self.partial_cmp(other) {
                Some(Ordering::Less) | Some(Ordering::Equal) => true,
                _ => false,
            }
        }
        
        fn gt(&self, other: &RHS) -> bool { 
            match self.partial_cmp(other) {
                Some(Ordering::Greater) => true,
                _ => false,
            }
        }
        
        fn ge(&self, other: &RHS) -> bool { 
            match self.partial_cmp(other) {
                Some(Ordering::Greater) | Some(Ordering::Equal) => true,
                _ => false,
            }
        }
    }

    // trait for types that can be totally ordered
    trait Ord: Eq + PartialOrd<Self> {
        fn cmp(&self, other: &Self) -> Ordering;
    }
}
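
When a type does have a total order, implementing `Ord` makes `sort()`, `max()` and ordered collections available. A sketch with a hypothetical `Task` type ordered by descending priority, then by name; `partial_cmp` simply delegates to `cmp` so the two traits cannot disagree:

```rust
use std::cmp::Ordering;

#[derive(Debug, PartialEq, Eq)]
struct Task {
    priority: u32,
    name: String,
}

// Higher priority first, then alphabetical by name. Every field takes part,
// so `cmp` returns `Equal` exactly when the derived `==` returns true.
impl Ord for Task {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .priority
            .cmp(&self.priority)
            .then_with(|| self.name.cmp(&other.name))
    }
}

// Delegate to `cmp` so `<`, `>`, etc. agree with `Ord`.
impl PartialOrd for Task {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut tasks = vec![
        Task { priority: 2, name: "write docs".to_string() },
        Task { priority: 1, name: "fix bug".to_string() },
        Task { priority: 2, name: "add tests".to_string() },
    ];
    tasks.sort(); // requires `Ord`
    let names: Vec<&str> = tasks.iter().map(|t| t.name.as_str()).collect();
    println!("{:?}", names); // ["add tests", "write docs", "fix bug"]
}
```

A derived `Ord` would compare fields in declaration order with ascending priority; the manual impl is what makes the reversed priority possible.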

In [7]:
#[derive(PartialEq, Eq, Debug)]
struct Interval<T> {
    lower: T, // inclusive
    upper: T  // exclusive 
}

In [8]:
use std::cmp::{Ordering, PartialOrd};

impl<T: PartialOrd> PartialOrd for Interval<T> {
    fn partial_cmp(&self, other: &Interval<T>) -> Option<Ordering> {
        if self == other {
            Some(Ordering::Equal)
        } else if self.lower >= other.upper {
            Some(Ordering::Greater)
        } else if self.upper <= other.lower {
            Some(Ordering::Less)
        } else {
            None
        }            
    }
}

In [9]:
assert!(Interval { lower: 10, upper: 20 } <  Interval { lower: 20, upper: 40 });
assert!(Interval { lower: 7,  upper: 8  } >= Interval { lower: 0,  upper: 1  });
assert!(Interval { lower: 7,  upper: 8  } <= Interval { lower: 7,  upper: 8  });

// Overlapping intervals aren't ordered with respect to each other.
let left  = Interval { lower: 10, upper: 30 };
let right = Interval { lower: 20, upper: 40 };
assert!(!(left < right));
assert!(!(left >= right));
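
Calling `partial_cmp` directly makes the possible outcomes visible: `Some(..)` for intervals that can be compared and `None` for overlapping ones. A self-contained sketch that repeats the `Interval` definition from above:

```rust
use std::cmp::Ordering;

#[derive(PartialEq, Eq, Debug)]
struct Interval<T> {
    lower: T, // inclusive
    upper: T, // exclusive
}

impl<T: PartialOrd> PartialOrd for Interval<T> {
    fn partial_cmp(&self, other: &Interval<T>) -> Option<Ordering> {
        if self == other {
            Some(Ordering::Equal)
        } else if self.lower >= other.upper {
            Some(Ordering::Greater)
        } else if self.upper <= other.lower {
            Some(Ordering::Less)
        } else {
            None // overlapping intervals are incomparable
        }
    }
}

fn main() {
    let a = Interval { lower: 10, upper: 20 };
    let b = Interval { lower: 20, upper: 40 };
    let c = Interval { lower: 15, upper: 25 };
    println!("{:?}", a.partial_cmp(&b)); // Some(Less)
    println!("{:?}", b.partial_cmp(&a)); // Some(Greater)
    println!("{:?}", a.partial_cmp(&c)); // None
}
```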

#### PartialOrd & f64

Floating-point numbers implement `PartialOrd` but not `Ord`, because `NaN` is neither greater than, less than, nor equal to any other value. As a result, `partial_cmp` returns `None` whenever `NaN` is involved:

In [10]:
3.14.partial_cmp(&f64::NAN)

None

In [11]:
f64::NAN.partial_cmp(&f64::NAN)

None

In [14]:
// Sorting integers works out of the box, because `i32` implements `Ord`

let mut data_ints = vec![1, 42, 665, 88, 9, 0, 2, 3, 4, 5];
data_ints.sort();
data_ints

[0, 1, 2, 3, 4, 5, 9, 42, 88, 665]

In [19]:
let mut vec_f64 = vec![1.0,  42.0, 777.9, 0.0 / 0.0, f64::NAN, 2.0, 3.0, 0.0, 4.0, 5.0];
vec_f64.sort_by(|a, b| f64::total_cmp(a, b));
vec_f64

[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 42.0, 777.9, NaN, NaN]

## Indexing

* To use the indexing expression `obj[i]` with a type, implement the `Index` trait (and, for mutable access such as `obj[i] = v`, the `IndexMut` trait):

In [2]:
mod explain {
    
    trait Index<Idx> {
        type Output: ?Sized;
        fn index(&self, index: Idx) -> &Self::Output;
    }

    trait IndexMut<Idx>: Index<Idx> {
        fn index_mut(&mut self, index: Idx) -> &mut Self::Output;
    }
}

In [6]:
struct Image<P> {
    width: usize,
    pixels: Vec<P>,
}

impl<P: Default + Copy> Image<P> {

    /// Create a new image with the given dimensions.
    fn new(width: usize, height: usize) -> Self {
        Image {
            width,
            pixels: vec![P::default(); width * height],
        }
    }
}

impl<P> std::ops::Index<usize> for Image<P> {
    type Output = [P];
    
    fn index(&self, row: usize) -> &[P] {
        let start = row * self.width;
        &self.pixels[start..start + self.width]
    }
}

impl<P> std::ops::IndexMut<usize> for Image<P> {
    fn index_mut(&mut self, row: usize) -> &mut [P] {
        let start = row * self.width;
        &mut self.pixels[start..start + self.width]
    }
}

In [8]:
let mut bitmap = Image::new(2, 2);

bitmap[0][0] = 1;
bitmap[0][1] = 2;

assert_eq!(bitmap[0][0], 1);
assert_eq!(bitmap[0][1], 2);
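
The `Idx` type parameter is not limited to `usize`. As a sketch (an addition, not in the original), the same `Image` could also accept a `(row, column)` tuple, so a single pixel is reached with `bitmap[(1, 0)]` instead of `bitmap[1][0]`:

```rust
use std::ops::{Index, IndexMut};

struct Image<P> {
    width: usize,
    pixels: Vec<P>,
}

impl<P: Default + Copy> Image<P> {
    fn new(width: usize, height: usize) -> Self {
        Image { width, pixels: vec![P::default(); width * height] }
    }
}

// Index by (row, column). Panics if `col` is outside the row, or if `row`
// is past the last row (through the `Vec` bounds check).
impl<P> Index<(usize, usize)> for Image<P> {
    type Output = P;
    fn index(&self, (row, col): (usize, usize)) -> &P {
        assert!(col < self.width, "column out of range");
        &self.pixels[row * self.width + col]
    }
}

impl<P> IndexMut<(usize, usize)> for Image<P> {
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut P {
        assert!(col < self.width, "column out of range");
        &mut self.pixels[row * self.width + col]
    }
}

fn main() {
    let mut bitmap: Image<u8> = Image::new(2, 2);
    bitmap[(1, 0)] = 7;
    println!("{}", bitmap[(1, 0)]); // 7
    println!("{}", bitmap[(0, 1)]); // 0 (the default value)
}
```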

# Sorting floats

* The usual `sort()` method on slices requires `T: Ord`, so it doesn't work for floats: because of `NaN`, `f32` and `f64` implement only `PartialOrd`.

In [10]:
let mut data = vec![2.1, 3.14, 0.3, 0.271, 0.0, 5.34, 1.0, 0.0];

data.sort();

Error: the trait bound `{float}: Ord` is not satisfied

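* `sort_by()` accepts a comparison closure, so we can combine it with the `partial_cmp()` method from the `std::cmp::PartialOrd` trait. This is only reliable when the slice contains no `NaN`: `partial_cmp()` returns `None` for `NaN`, and a fallback such as `unwrap_or(Ordering::Less)` does not define a consistent ordering, so the `NaN` values end up in an arbitrary position (here, right after `-inf`). Since Rust 1.81 the slice sort functions may even panic when the comparator is not a total order: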

In [26]:
use std::cmp::Ordering;

let mut data = vec![2.1, 3.14, 0.3, 0.271, 0.0, std::f64::NAN, 5.34, 1.0, 0.0/0.0, std::f64::consts::PI, std::f64::INFINITY, std::f64::NEG_INFINITY];

data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Less));

println!("{:?}", data);

[-inf, NaN, NaN, 0.0, 0.271, 0.3, 1.0, 2.1, 3.14, 3.141592653589793, 5.34, inf]
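
When the data is known to contain no `NaN`, unwrapping the comparison is the simplest correct option: it sorts normally and panics loudly when a comparison involves a `NaN`, instead of silently misplacing it. A sketch (the helper name `sort_no_nan` is hypothetical):

```rust
// Sorts floats under the assumption that none of them is NaN.
fn sort_no_nan(data: &mut [f64]) {
    // `partial_cmp` returns `None` only when a NaN is involved;
    // `expect` turns that case into a panic with a clear message.
    data.sort_by(|a, b| a.partial_cmp(b).expect("NaN in input"));
}

fn main() {
    let mut data = vec![2.1, 3.14, 0.3, 0.271, 0.0, 5.34, 1.0, f64::INFINITY, f64::NEG_INFINITY];
    sort_no_nan(&mut data);
    println!("{:?}", data);
    // [-inf, 0.0, 0.271, 0.3, 1.0, 2.1, 3.14, 5.34, inf]
}
```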


## Total ordering for floats

* Since Rust 1.62 we can use the `f64::total_cmp()` method. Unlike `partial_cmp()`, it returns a plain `Ordering` rather than an `Option<Ordering>`, so it can be passed directly to `sort_by()`.

* The floating-point values are totally ordered in the following sequence:

  * negative quiet NaN
  * negative signaling NaN
  * negative infinity
  * negative numbers
  * negative subnormal numbers
  * negative zero
  * positive zero
  * positive subnormal numbers
  * positive numbers
  * positive infinity
  * positive signaling NaN
  * positive quiet NaN.
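
Two consequences of this order differ from `partial_cmp`: `-0.0` and `0.0` are no longer equal, and a `NaN` with its sign bit set sorts before negative infinity. A small check; `-f64::NAN` is used here only as a convenient way to obtain a negative NaN:

```rust
use std::cmp::Ordering;

fn main() {
    // partial_cmp treats the two zeros as equal; total_cmp does not.
    assert_eq!((-0.0_f64).partial_cmp(&0.0), Some(Ordering::Equal));
    assert_eq!((-0.0_f64).total_cmp(&0.0), Ordering::Less);

    // Negation flips the sign bit, so -NaN moves to the front of the order.
    let neg_nan = -f64::NAN;
    assert_eq!(neg_nan.total_cmp(&f64::NEG_INFINITY), Ordering::Less);
    assert_eq!(f64::NAN.total_cmp(&f64::INFINITY), Ordering::Greater);

    println!("all checks passed");
}
```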

In [21]:
f64::NAN.total_cmp(&std::f64::consts::PI)

Greater

In [25]:
let mut data = vec![2.1, 3.14, 0.3, 0.271, 0.0, std::f64::NAN, 5.34, 1.0, 0.0/0.0, std::f64::consts::PI, std::f64::INFINITY, std::f64::NEG_INFINITY];

// data.sort_by(|a, b| f64::total_cmp(a, b));
data.sort_by(f64::total_cmp);

println!("{:?}", data);

[-inf, 0.0, 0.271, 0.3, 1.0, 2.1, 3.14, 3.141592653589793, 5.34, inf, NaN, NaN]
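
`total_cmp` can also back a full `Ord` implementation. The sketch below uses a hypothetical newtype `TotalF64` so that floats work with plain `sort()`, `max()` or a `BinaryHeap`. Equality is defined through `total_cmp` as well; otherwise `Eq` and `Ord` would disagree on `NaN` and on the two zeros.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

#[derive(Debug, Clone, Copy)]
struct TotalF64(f64);

// Consistent with `total_cmp`: values are equal only when their bit
// patterns are identical (so NaN == NaN, but -0.0 != 0.0).
impl PartialEq for TotalF64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == Ordering::Equal
    }
}

impl Eq for TotalF64 {}

impl Ord for TotalF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}

impl PartialOrd for TotalF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut data: Vec<TotalF64> = [2.1, f64::NAN, -1.0, 0.5]
        .iter()
        .map(|&x| TotalF64(x))
        .collect();
    data.sort(); // plain `sort()` now compiles
    println!("{:?}", data);
    // [TotalF64(-1.0), TotalF64(0.5), TotalF64(2.1), TotalF64(NaN)]

    let heap: BinaryHeap<TotalF64> = [1.5, 3.0, 2.0].iter().map(|&x| TotalF64(x)).collect();
    println!("{:?}", heap.peek()); // Some(TotalF64(3.0))
}
```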
